diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md b/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md index 34b11ac7..f856d1ad 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md @@ -328,7 +328,7 @@ The complete list of the available parameters for the `training.py` script conta - `--gradient_clip`: Clipping of the gradient norms (default: `None`) - `--accumulate_grad_batches`: Gradient accumulation (default: `1`) - `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`) -- `--eval_interval`: Do an evaluation round every N epochs (default: `1`) +- `--eval_interval`: Do an evaluation round every N epochs (default: `20`) - `--silent`: Minimize stdout output (default: `false`) **Paths** @@ -485,6 +485,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t | 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x | + #### Training performance results ##### Training performance: NVIDIA DGX A100 (8x A100 80GB) @@ -495,8 +496,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:| | 1 | 240 | 2.21 | 2.92 | 1.32x | | | | 1 | 120 | 1.81 | 2.04 | 1.13x | | | -| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 | -| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 | +| 8 | 240 | 15.88 | 21.02 | 1.32x | 7.18 | 7.20 | +| 8 | 120 | 12.68 | 13.99 | 1.10x | 7.00 | 6.86 | To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide). @@ -510,8 +511,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:| | 1 | 240 | 1.25 | 1.88 | 1.50x | | | | 1 | 120 | 1.03 | 1.41 | 1.37x | | | -| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 | -| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 | +| 8 | 240 | 8.68 | 12.75 | 1.47x | 6.94 | 6.78 | +| 8 | 120 | 6.64 | 8.58 | 1.29x | 6.44 | 6.08 | To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide). @@ -572,6 +573,15 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic ### Changelog +October 2021: +- Updated README performance tables +- Fixed shape mismatch when using partially fused TFNs per output degree +- Fixed shape mismatch when using partially fused TFNs per input degree with edge degrees > 0 + +September 2021: +- Moved to new location (from `PyTorch/DrugDiscovery` to `DGLPyTorch/DrugDiscovery`) +- Fixed multi-GPUs training script + August 2021 - Initial release diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py index 628d01e9..6d796b2d 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py @@ -1,2 +1,2 @@ -from .transformer import SE3Transformer, SE3TransformerPooled -from .fiber import Fiber +from .transformer import SE3Transformer, SE3TransformerPooled +from .fiber import Fiber diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/basis.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/basis.py index 74f04a0f..44fa4349 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/basis.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/basis.py @@ -54,9 +54,8 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]: def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]: all_degrees = list(range(2 * max_degree + 1)) - with nvtx_range('spherical harmonics'): - sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True) - return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1) + sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True) + return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1) @torch.jit.script diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py index 38db33b0..50d5c6da 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py @@ -1,144 +1,144 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from collections import namedtuple -from itertools import product -from typing import Dict - -import torch -from torch import Tensor - -from se3_transformer.runtime.utils import degree_to_dim - -FiberEl = namedtuple('FiberEl', ['degree', 'channels']) - - -class Fiber(dict): - """ - Describes the structure of some set of features. - Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. - Type-0 features: invariant scalars - Type-1 features: equivariant 3D vectors - Type-2 features: equivariant symmetric traceless matrices - ... - - As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. - The 'multiplicity' or 'number of channels' is the number of features of a given type. - This class puts together all the degrees and their multiplicities in order to describe - the inputs, outputs or hidden features of SE3 layers. - """ - - def __init__(self, structure): - if isinstance(structure, dict): - structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] - elif not isinstance(structure[0], FiberEl): - structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) - self.structure = structure - super().__init__({d: m for d, m in self.structure}) - - @property - def degrees(self): - return sorted([t.degree for t in self.structure]) - - @property - def channels(self): - return [self[d] for d in self.degrees] - - @property - def num_features(self): - """ Size of the resulting tensor if all features were concatenated together """ - return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) - - @staticmethod - def create(num_degrees: int, num_channels: int): - """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ - return Fiber([(degree, num_channels) for degree in range(num_degrees)]) - - @staticmethod - def from_features(feats: Dict[str, Tensor]): - """ Infer the Fiber structure from a feature dict """ - structure = {} - for k, v in feats.items(): - degree = int(k) - assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' - assert v.shape[-1] == degree_to_dim(degree) - structure[degree] = v.shape[-2] - return Fiber(structure) - - def __getitem__(self, degree: int): - """ fiber[degree] returns the multiplicity for this degree """ - return dict(self.structure).get(degree, 0) - - def __iter__(self): - """ Iterate over namedtuples (degree, channels) """ - return iter(self.structure) - - def __mul__(self, other): - """ - If other in an int, multiplies all the multiplicities by other. - If other is a fiber, returns the cartesian product. - """ - if isinstance(other, Fiber): - return product(self.structure, other.structure) - elif isinstance(other, int): - return Fiber({t.degree: t.channels * other for t in self.structure}) - - def __add__(self, other): - """ - If other in an int, add other to all the multiplicities. - If other is a fiber, add the multiplicities of the fibers together. - """ - if isinstance(other, Fiber): - return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) - elif isinstance(other, int): - return Fiber({t.degree: t.channels + other for t in self.structure}) - - def __repr__(self): - return str(self.structure) - - @staticmethod - def combine_max(f1, f2): - """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ - new_dict = dict(f1.structure) - for k, m in f2.structure: - new_dict[k] = max(new_dict.get(k, 0), m) - - return Fiber(list(new_dict.items())) - - @staticmethod - def combine_selectively(f1, f2): - """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ - # only use orders which occur in fiber f1 - new_dict = dict(f1.structure) - for k in f1.degrees: - if k in f2.degrees: - new_dict[k] += f2[k] - return Fiber(list(new_dict.items())) - - def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): - # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) - fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in - self.degrees] - fibers = torch.cat(fibers, -1) - return fibers +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from collections import namedtuple +from itertools import product +from typing import Dict + +import torch +from torch import Tensor + +from se3_transformer.runtime.utils import degree_to_dim + +FiberEl = namedtuple('FiberEl', ['degree', 'channels']) + + +class Fiber(dict): + """ + Describes the structure of some set of features. + Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. + Type-0 features: invariant scalars + Type-1 features: equivariant 3D vectors + Type-2 features: equivariant symmetric traceless matrices + ... + + As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. + The 'multiplicity' or 'number of channels' is the number of features of a given type. + This class puts together all the degrees and their multiplicities in order to describe + the inputs, outputs or hidden features of SE3 layers. + """ + + def __init__(self, structure): + if isinstance(structure, dict): + structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] + elif not isinstance(structure[0], FiberEl): + structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) + self.structure = structure + super().__init__({d: m for d, m in self.structure}) + + @property + def degrees(self): + return sorted([t.degree for t in self.structure]) + + @property + def channels(self): + return [self[d] for d in self.degrees] + + @property + def num_features(self): + """ Size of the resulting tensor if all features were concatenated together """ + return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) + + @staticmethod + def create(num_degrees: int, num_channels: int): + """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ + return Fiber([(degree, num_channels) for degree in range(num_degrees)]) + + @staticmethod + def from_features(feats: Dict[str, Tensor]): + """ Infer the Fiber structure from a feature dict """ + structure = {} + for k, v in feats.items(): + degree = int(k) + assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' + assert v.shape[-1] == degree_to_dim(degree) + structure[degree] = v.shape[-2] + return Fiber(structure) + + def __getitem__(self, degree: int): + """ fiber[degree] returns the multiplicity for this degree """ + return dict(self.structure).get(degree, 0) + + def __iter__(self): + """ Iterate over namedtuples (degree, channels) """ + return iter(self.structure) + + def __mul__(self, other): + """ + If other in an int, multiplies all the multiplicities by other. + If other is a fiber, returns the cartesian product. + """ + if isinstance(other, Fiber): + return product(self.structure, other.structure) + elif isinstance(other, int): + return Fiber({t.degree: t.channels * other for t in self.structure}) + + def __add__(self, other): + """ + If other in an int, add other to all the multiplicities. + If other is a fiber, add the multiplicities of the fibers together. + """ + if isinstance(other, Fiber): + return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) + elif isinstance(other, int): + return Fiber({t.degree: t.channels + other for t in self.structure}) + + def __repr__(self): + return str(self.structure) + + @staticmethod + def combine_max(f1, f2): + """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ + new_dict = dict(f1.structure) + for k, m in f2.structure: + new_dict[k] = max(new_dict.get(k, 0), m) + + return Fiber(list(new_dict.items())) + + @staticmethod + def combine_selectively(f1, f2): + """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ + # only use orders which occur in fiber f1 + new_dict = dict(f1.structure) + for k in f1.degrees: + if k in f2.degrees: + new_dict[k] += f2[k] + return Fiber(list(new_dict.items())) + + def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): + # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) + fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in + self.degrees] + fibers = torch.cat(fibers, -1) + return fibers diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py index 9eb9e3ce..2d6b14fb 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py @@ -1,5 +1,5 @@ -from .linear import LinearSE3 -from .norm import NormSE3 -from .pooling import GPooling -from .convolution import ConvSE3 +from .linear import LinearSE3 +from .norm import NormSE3 +from .pooling import GPooling +from .convolution import ConvSE3 from .attention import AttentionBlockSE3 \ No newline at end of file diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py index 90b1e456..11a4fb16 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py @@ -1,182 +1,180 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import dgl -import numpy as np -import torch -import torch.nn as nn -from dgl import DGLGraph -from dgl.ops import edge_softmax -from torch import Tensor -from typing import Dict, Optional, Union - -from se3_transformer.model.fiber import Fiber -from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel -from se3_transformer.model.layers.linear import LinearSE3 -from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features -from torch.cuda.nvtx import range as nvtx_range - - -class AttentionSE3(nn.Module): - """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """ - - def __init__( - self, - num_heads: int, - key_fiber: Fiber, - value_fiber: Fiber - ): - """ - :param num_heads: Number of attention heads - :param key_fiber: Fiber for the keys (and also for the queries) - :param value_fiber: Fiber for the values - """ - super().__init__() - self.num_heads = num_heads - self.key_fiber = key_fiber - self.value_fiber = value_fiber - - def forward( - self, - value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) - key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) - query: Dict[str, Tensor], # node features - graph: DGLGraph - ): - with nvtx_range('AttentionSE3'): - with nvtx_range('reshape keys and queries'): - if isinstance(key, Tensor): - # case where features of all types are fused - key = key.reshape(key.shape[0], self.num_heads, -1) - # need to reshape queries that way to keep the same layout as keys - out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1) - query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1) - else: - # features are not fused, need to fuse and reshape them - key = self.key_fiber.to_attention_heads(key, self.num_heads) - query = self.key_fiber.to_attention_heads(query, self.num_heads) - - with nvtx_range('attention dot product + softmax'): - # Compute attention weights (softmax of inner product between key and query) - edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1) - edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features) - edge_weights = edge_softmax(graph, edge_weights) - edge_weights = edge_weights[..., None, None] - - with nvtx_range('weighted sum'): - if isinstance(value, Tensor): - # features of all types are fused - v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1]) - weights = edge_weights * v - feat_out = dgl.ops.copy_e_sum(graph, weights) - feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads - out = unfuse_features(feat_out, self.value_fiber.degrees) - else: - out = {} - for degree, channels in self.value_fiber: - v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads, - degree_to_dim(degree)) - weights = edge_weights * v - res = dgl.ops.copy_e_sum(graph, weights) - out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads - - return out - - -class AttentionBlockSE3(nn.Module): - """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """ - - def __init__( - self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Optional[Fiber] = None, - num_heads: int = 4, - channels_div: int = 2, - use_layer_norm: bool = False, - max_degree: bool = 4, - fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, - **kwargs - ): - """ - :param fiber_in: Fiber describing the input features - :param fiber_out: Fiber describing the output features - :param fiber_edge: Fiber describing the edge features (node distances excluded) - :param num_heads: Number of attention heads - :param channels_div: Divide the channels by this integer for computing values - :param use_layer_norm: Apply layer normalization between MLP layers - :param max_degree: Maximum degree used in the bases computation - :param fuse_level: Maximum fuse level to use in TFN convolutions - """ - super().__init__() - if fiber_edge is None: - fiber_edge = Fiber({}) - self.fiber_in = fiber_in - # value_fiber has same structure as fiber_out but #channels divided by 'channels_div' - value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out]) - # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber - # (queries are merely projected, hence degrees have to match input) - key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees]) - - self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge, - use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level, - allow_fused_output=True) - self.to_query = LinearSE3(fiber_in, key_query_fiber) - self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber) - self.project = LinearSE3(value_fiber + fiber_in, fiber_out) - - def forward( - self, - node_features: Dict[str, Tensor], - edge_features: Dict[str, Tensor], - graph: DGLGraph, - basis: Dict[str, Tensor] - ): - with nvtx_range('AttentionBlockSE3'): - with nvtx_range('keys / values'): - fused_key_value = self.to_key_value(node_features, edge_features, graph, basis) - key, value = self._get_key_value_from_fused(fused_key_value) - - with nvtx_range('queries'): - query = self.to_query(node_features) - - z = self.attention(value, key, query, graph) - z_concat = aggregate_residual(node_features, z, 'cat') - return self.project(z_concat) - - def _get_key_value_from_fused(self, fused_key_value): - # Extract keys and queries features from fused features - if isinstance(fused_key_value, Tensor): - # Previous layer was a fully fused convolution - value, key = torch.chunk(fused_key_value, chunks=2, dim=-2) - else: - key, value = {}, {} - for degree, feat in fused_key_value.items(): - if int(degree) in self.fiber_in.degrees: - value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2) - else: - value[degree] = feat - - return key, value - - +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import dgl +import numpy as np +import torch +import torch.nn as nn +from dgl import DGLGraph +from dgl.ops import edge_softmax +from torch import Tensor +from typing import Dict, Optional, Union + +from se3_transformer.model.fiber import Fiber +from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel +from se3_transformer.model.layers.linear import LinearSE3 +from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features +from torch.cuda.nvtx import range as nvtx_range + + +class AttentionSE3(nn.Module): + """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """ + + def __init__( + self, + num_heads: int, + key_fiber: Fiber, + value_fiber: Fiber + ): + """ + :param num_heads: Number of attention heads + :param key_fiber: Fiber for the keys (and also for the queries) + :param value_fiber: Fiber for the values + """ + super().__init__() + self.num_heads = num_heads + self.key_fiber = key_fiber + self.value_fiber = value_fiber + + def forward( + self, + value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) + key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) + query: Dict[str, Tensor], # node features + graph: DGLGraph + ): + with nvtx_range('AttentionSE3'): + with nvtx_range('reshape keys and queries'): + if isinstance(key, Tensor): + # case where features of all types are fused + key = key.reshape(key.shape[0], self.num_heads, -1) + # need to reshape queries that way to keep the same layout as keys + out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1) + query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1) + else: + # features are not fused, need to fuse and reshape them + key = self.key_fiber.to_attention_heads(key, self.num_heads) + query = self.key_fiber.to_attention_heads(query, self.num_heads) + + with nvtx_range('attention dot product + softmax'): + # Compute attention weights (softmax of inner product between key and query) + edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1) + edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features) + edge_weights = edge_softmax(graph, edge_weights) + edge_weights = edge_weights[..., None, None] + + with nvtx_range('weighted sum'): + if isinstance(value, Tensor): + # features of all types are fused + v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1]) + weights = edge_weights * v + feat_out = dgl.ops.copy_e_sum(graph, weights) + feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads + out = unfuse_features(feat_out, self.value_fiber.degrees) + else: + out = {} + for degree, channels in self.value_fiber: + v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads, + degree_to_dim(degree)) + weights = edge_weights * v + res = dgl.ops.copy_e_sum(graph, weights) + out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads + + return out + + +class AttentionBlockSE3(nn.Module): + """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """ + + def __init__( + self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Optional[Fiber] = None, + num_heads: int = 4, + channels_div: int = 2, + use_layer_norm: bool = False, + max_degree: bool = 4, + fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, + **kwargs + ): + """ + :param fiber_in: Fiber describing the input features + :param fiber_out: Fiber describing the output features + :param fiber_edge: Fiber describing the edge features (node distances excluded) + :param num_heads: Number of attention heads + :param channels_div: Divide the channels by this integer for computing values + :param use_layer_norm: Apply layer normalization between MLP layers + :param max_degree: Maximum degree used in the bases computation + :param fuse_level: Maximum fuse level to use in TFN convolutions + """ + super().__init__() + if fiber_edge is None: + fiber_edge = Fiber({}) + self.fiber_in = fiber_in + # value_fiber has same structure as fiber_out but #channels divided by 'channels_div' + value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out]) + # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber + # (queries are merely projected, hence degrees have to match input) + key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees]) + + self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge, + use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level, + allow_fused_output=True) + self.to_query = LinearSE3(fiber_in, key_query_fiber) + self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber) + self.project = LinearSE3(value_fiber + fiber_in, fiber_out) + + def forward( + self, + node_features: Dict[str, Tensor], + edge_features: Dict[str, Tensor], + graph: DGLGraph, + basis: Dict[str, Tensor] + ): + with nvtx_range('AttentionBlockSE3'): + with nvtx_range('keys / values'): + fused_key_value = self.to_key_value(node_features, edge_features, graph, basis) + key, value = self._get_key_value_from_fused(fused_key_value) + + with nvtx_range('queries'): + query = self.to_query(node_features) + + z = self.attention(value, key, query, graph) + z_concat = aggregate_residual(node_features, z, 'cat') + return self.project(z_concat) + + def _get_key_value_from_fused(self, fused_key_value): + # Extract keys and queries features from fused features + if isinstance(fused_key_value, Tensor): + # Previous layer was a fully fused convolution + value, key = torch.chunk(fused_key_value, chunks=2, dim=-2) + else: + key, value = {}, {} + for degree, feat in fused_key_value.items(): + if int(degree) in self.fiber_in.degrees: + value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2) + else: + value[degree] = feat + + return key, value diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py index d58adc4f..0cab99e1 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py @@ -1,334 +1,345 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -from enum import Enum -from itertools import product -from typing import Dict - -import dgl -import numpy as np -import torch -import torch.nn as nn -from dgl import DGLGraph -from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range - -from se3_transformer.model.fiber import Fiber -from se3_transformer.runtime.utils import degree_to_dim, unfuse_features - - -class ConvSE3FuseLevel(Enum): - """ - Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met. - If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered. - A higher level means faster training, but also more memory usage. - If you are tight on memory and want to feed large inputs to the network, choose a low value. - If you want to train fast, choose a high value. - Recommended value is FULL with AMP. - - Fully fused TFN convolutions requirements: - - all input channels are the same - - all output channels are the same - - input degrees span the range [0, ..., max_degree] - - output degrees span the range [0, ..., max_degree] - - Partially fused TFN convolutions requirements: - * For fusing by output degree: - - all input channels are the same - - input degrees span the range [0, ..., max_degree] - * For fusing by input degree: - - all output channels are the same - - output degrees span the range [0, ..., max_degree] - - Original TFN pairwise convolutions: no requirements - """ - - FULL = 2 - PARTIAL = 1 - NONE = 0 - - -class RadialProfile(nn.Module): - """ - Radial profile function. - Outputs weights used to weigh basis matrices in order to get convolution kernels. - In TFN notation: $R^{l,k}$ - In SE(3)-Transformer notation: $\phi^{l,k}$ - - Note: - In the original papers, this function only depends on relative node distances ||x||. - Here, we allow this function to also take as input additional invariant edge features. - This does not break equivariance and adds expressive power to the model. - - Diagram: - invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights - """ - - def __init__( - self, - num_freq: int, - channels_in: int, - channels_out: int, - edge_dim: int = 1, - mid_dim: int = 32, - use_layer_norm: bool = False - ): - """ - :param num_freq: Number of frequencies - :param channels_in: Number of input channels - :param channels_out: Number of output channels - :param edge_dim: Number of invariant edge features (input to the radial function) - :param mid_dim: Size of the hidden MLP layers - :param use_layer_norm: Apply layer normalization between MLP layers - """ - super().__init__() - modules = [ - nn.Linear(edge_dim, mid_dim), - nn.LayerNorm(mid_dim) if use_layer_norm else None, - nn.ReLU(), - nn.Linear(mid_dim, mid_dim), - nn.LayerNorm(mid_dim) if use_layer_norm else None, - nn.ReLU(), - nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False) - ] - - self.net = nn.Sequential(*[m for m in modules if m is not None]) - - def forward(self, features: Tensor) -> Tensor: - return self.net(features) - - -class VersatileConvSE3(nn.Module): - """ - Building block for TFN convolutions. - This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions. - """ - - def __init__(self, - freq_sum: int, - channels_in: int, - channels_out: int, - edge_dim: int, - use_layer_norm: bool, - fuse_level: ConvSE3FuseLevel): - super().__init__() - self.freq_sum = freq_sum - self.channels_out = channels_out - self.channels_in = channels_in - self.fuse_level = fuse_level - self.radial_func = RadialProfile(num_freq=freq_sum, - channels_in=channels_in, - channels_out=channels_out, - edge_dim=edge_dim, - use_layer_norm=use_layer_norm) - - def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor): - with nvtx_range(f'VersatileConvSE3'): - num_edges = features.shape[0] - in_dim = features.shape[2] - with nvtx_range(f'RadialProfile'): - radial_weights = self.radial_func(invariant_edge_feats) \ - .view(-1, self.channels_out, self.channels_in * self.freq_sum) - - if basis is not None: - # This block performs the einsum n i l, n o i f, n l f k -> n o k - out_dim = basis.shape[-1] - if self.fuse_level != ConvSE3FuseLevel.FULL: - out_dim += out_dim % 2 - 1 # Account for padded basis - basis_view = basis.view(num_edges, in_dim, -1) - tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1]) - return (radial_weights @ tmp)[:, :, :out_dim] - else: - # k = l = 0 non-fused case - return radial_weights @ features - - -class ConvSE3(nn.Module): - """ - SE(3)-equivariant graph convolution (Tensor Field Network convolution). - This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance. - Features of different degrees interact together to produce output features. - - Note 1: - The option is given to not pool the output. This means that the convolution sum over neighbors will not be - done, and the returned features will be edge features instead of node features. - - Note 2: - Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0. - Input edge features are concatenated with input source node features before the kernel is applied. - """ - - def __init__( - self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Fiber, - pool: bool = True, - use_layer_norm: bool = False, - self_interaction: bool = False, - max_degree: int = 4, - fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, - allow_fused_output: bool = False - ): - """ - :param fiber_in: Fiber describing the input features - :param fiber_out: Fiber describing the output features - :param fiber_edge: Fiber describing the edge features (node distances excluded) - :param pool: If True, compute final node features by averaging incoming edge features - :param use_layer_norm: Apply layer normalization between MLP layers - :param self_interaction: Apply self-interaction of nodes - :param max_degree: Maximum degree used in the bases computation - :param fuse_level: Maximum fuse level to use in TFN convolutions - :param allow_fused_output: Allow the module to output a fused representation of features - """ - super().__init__() - self.pool = pool - self.fiber_in = fiber_in - self.fiber_out = fiber_out - self.self_interaction = self_interaction - self.max_degree = max_degree - self.allow_fused_output = allow_fused_output - - # channels_in: account for the concatenation of edge features - channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in]) - channels_out_set = set([f.channels for f in self.fiber_out]) - unique_channels_in = (len(channels_in_set) == 1) - unique_channels_out = (len(channels_out_set) == 1) - degrees_up_to_max = list(range(max_degree + 1)) - common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm) - - if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \ - unique_channels_in and fiber_in.degrees == degrees_up_to_max and \ - unique_channels_out and fiber_out.degrees == degrees_up_to_max: - # Single fused convolution - self.used_fuse_level = ConvSE3FuseLevel.FULL - - sum_freq = sum([ - degree_to_dim(min(d_in, d_out)) - for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max) - ]) - - self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0], - fuse_level=self.used_fuse_level, **common_args) - - elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ - unique_channels_in and fiber_in.degrees == degrees_up_to_max: - # Convolutions fused per output degree - self.used_fuse_level = ConvSE3FuseLevel.PARTIAL - self.conv_out = nn.ModuleDict() - for d_out, c_out in fiber_out: - sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees]) - self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out, - fuse_level=self.used_fuse_level, **common_args) - - elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ - unique_channels_out and fiber_out.degrees == degrees_up_to_max: - # Convolutions fused per input degree - self.used_fuse_level = ConvSE3FuseLevel.PARTIAL - self.conv_in = nn.ModuleDict() - for d_in, c_in in fiber_in: - sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees]) - self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0], - fuse_level=self.used_fuse_level, **common_args) - else: - # Use pairwise TFN convolutions - self.used_fuse_level = ConvSE3FuseLevel.NONE - self.conv = nn.ModuleDict() - for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out): - dict_key = f'{degree_in},{degree_out}' - channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0) - sum_freq = degree_to_dim(min(degree_in, degree_out)) - self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out, - fuse_level=self.used_fuse_level, **common_args) - - if self_interaction: - self.to_kernel_self = nn.ParameterDict() - for degree_out, channels_out in fiber_out: - if fiber_in[degree_out]: - self.to_kernel_self[str(degree_out)] = nn.Parameter( - torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) - - def forward( - self, - node_feats: Dict[str, Tensor], - edge_feats: Dict[str, Tensor], - graph: DGLGraph, - basis: Dict[str, Tensor] - ): - with nvtx_range(f'ConvSE3'): - invariant_edge_feats = edge_feats['0'].squeeze(-1) - src, dst = graph.edges() - out = {} - in_features = [] - - # Fetch all input features from edge and node features - for degree_in in self.fiber_in.degrees: - src_node_features = node_feats[str(degree_in)][src] - if degree_in > 0 and str(degree_in) in edge_feats: - # Handle edge features of any type by concatenating them to node features - src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1) - in_features.append(src_node_features) - - if self.used_fuse_level == ConvSE3FuseLevel.FULL: - in_features_fused = torch.cat(in_features, dim=-1) - out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused']) - - if not self.allow_fused_output or self.self_interaction or self.pool: - out = unfuse_features(out, self.fiber_out.degrees) - - elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'): - in_features_fused = torch.cat(in_features, dim=-1) - for degree_out in self.fiber_out.degrees: - out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis[f'out{degree_out}_fused']) - - elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'): - out = 0 - for degree_in, feature in zip(self.fiber_in.degrees, in_features): - out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused']) - if not self.allow_fused_output or self.self_interaction or self.pool: - out = unfuse_features(out, self.fiber_out.degrees) - else: - # Fallback to pairwise TFN convolutions - for degree_out in self.fiber_out.degrees: - out_feature = 0 - for degree_in, feature in zip(self.fiber_in.degrees, in_features): - dict_key = f'{degree_in},{degree_out}' - out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats, basis.get(dict_key, None)) - out[str(degree_out)] = out_feature - - for degree_out in self.fiber_out.degrees: - if self.self_interaction and str(degree_out) in self.to_kernel_self: - with nvtx_range(f'self interaction'): - dst_features = node_feats[str(degree_out)][dst] - kernel_self = self.to_kernel_self[str(degree_out)] - out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features - - if self.pool: - with nvtx_range(f'pooling'): - if isinstance(out, dict): - out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)]) - else: - out = dgl.ops.copy_e_sum(graph, out) - return out - - +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +from enum import Enum +from itertools import product +from typing import Dict + +import dgl +import numpy as np +import torch +import torch.nn as nn +from dgl import DGLGraph +from torch import Tensor +from torch.cuda.nvtx import range as nvtx_range + +from se3_transformer.model.fiber import Fiber +from se3_transformer.runtime.utils import degree_to_dim, unfuse_features + + +class ConvSE3FuseLevel(Enum): + """ + Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met. + If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered. + A higher level means faster training, but also more memory usage. + If you are tight on memory and want to feed large inputs to the network, choose a low value. + If you want to train fast, choose a high value. + Recommended value is FULL with AMP. + + Fully fused TFN convolutions requirements: + - all input channels are the same + - all output channels are the same + - input degrees span the range [0, ..., max_degree] + - output degrees span the range [0, ..., max_degree] + + Partially fused TFN convolutions requirements: + * For fusing by output degree: + - all input channels are the same + - input degrees span the range [0, ..., max_degree] + * For fusing by input degree: + - all output channels are the same + - output degrees span the range [0, ..., max_degree] + + Original TFN pairwise convolutions: no requirements + """ + + FULL = 2 + PARTIAL = 1 + NONE = 0 + + +class RadialProfile(nn.Module): + """ + Radial profile function. + Outputs weights used to weigh basis matrices in order to get convolution kernels. + In TFN notation: $R^{l,k}$ + In SE(3)-Transformer notation: $\phi^{l,k}$ + + Note: + In the original papers, this function only depends on relative node distances ||x||. + Here, we allow this function to also take as input additional invariant edge features. + This does not break equivariance and adds expressive power to the model. + + Diagram: + invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights + """ + + def __init__( + self, + num_freq: int, + channels_in: int, + channels_out: int, + edge_dim: int = 1, + mid_dim: int = 32, + use_layer_norm: bool = False + ): + """ + :param num_freq: Number of frequencies + :param channels_in: Number of input channels + :param channels_out: Number of output channels + :param edge_dim: Number of invariant edge features (input to the radial function) + :param mid_dim: Size of the hidden MLP layers + :param use_layer_norm: Apply layer normalization between MLP layers + """ + super().__init__() + modules = [ + nn.Linear(edge_dim, mid_dim), + nn.LayerNorm(mid_dim) if use_layer_norm else None, + nn.ReLU(), + nn.Linear(mid_dim, mid_dim), + nn.LayerNorm(mid_dim) if use_layer_norm else None, + nn.ReLU(), + nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False) + ] + + self.net = nn.Sequential(*[m for m in modules if m is not None]) + + def forward(self, features: Tensor) -> Tensor: + return self.net(features) + + +class VersatileConvSE3(nn.Module): + """ + Building block for TFN convolutions. + This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions. + """ + + def __init__(self, + freq_sum: int, + channels_in: int, + channels_out: int, + edge_dim: int, + use_layer_norm: bool, + fuse_level: ConvSE3FuseLevel): + super().__init__() + self.freq_sum = freq_sum + self.channels_out = channels_out + self.channels_in = channels_in + self.fuse_level = fuse_level + self.radial_func = RadialProfile(num_freq=freq_sum, + channels_in=channels_in, + channels_out=channels_out, + edge_dim=edge_dim, + use_layer_norm=use_layer_norm) + + def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor): + with nvtx_range(f'VersatileConvSE3'): + num_edges = features.shape[0] + in_dim = features.shape[2] + with nvtx_range(f'RadialProfile'): + radial_weights = self.radial_func(invariant_edge_feats) \ + .view(-1, self.channels_out, self.channels_in * self.freq_sum) + + if basis is not None: + # This block performs the einsum n i l, n o i f, n l f k -> n o k + basis_view = basis.view(num_edges, in_dim, -1) + tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1]) + return radial_weights @ tmp + else: + # k = l = 0 non-fused case + return radial_weights @ features + + +class ConvSE3(nn.Module): + """ + SE(3)-equivariant graph convolution (Tensor Field Network convolution). + This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance. + Features of different degrees interact together to produce output features. + + Note 1: + The option is given to not pool the output. This means that the convolution sum over neighbors will not be + done, and the returned features will be edge features instead of node features. + + Note 2: + Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0. + Input edge features are concatenated with input source node features before the kernel is applied. + """ + + def __init__( + self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Fiber, + pool: bool = True, + use_layer_norm: bool = False, + self_interaction: bool = False, + max_degree: int = 4, + fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, + allow_fused_output: bool = False + ): + """ + :param fiber_in: Fiber describing the input features + :param fiber_out: Fiber describing the output features + :param fiber_edge: Fiber describing the edge features (node distances excluded) + :param pool: If True, compute final node features by averaging incoming edge features + :param use_layer_norm: Apply layer normalization between MLP layers + :param self_interaction: Apply self-interaction of nodes + :param max_degree: Maximum degree used in the bases computation + :param fuse_level: Maximum fuse level to use in TFN convolutions + :param allow_fused_output: Allow the module to output a fused representation of features + """ + super().__init__() + self.pool = pool + self.fiber_in = fiber_in + self.fiber_out = fiber_out + self.self_interaction = self_interaction + self.max_degree = max_degree + self.allow_fused_output = allow_fused_output + + # channels_in: account for the concatenation of edge features + channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in]) + channels_out_set = set([f.channels for f in self.fiber_out]) + unique_channels_in = (len(channels_in_set) == 1) + unique_channels_out = (len(channels_out_set) == 1) + degrees_up_to_max = list(range(max_degree + 1)) + common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm) + + if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \ + unique_channels_in and fiber_in.degrees == degrees_up_to_max and \ + unique_channels_out and fiber_out.degrees == degrees_up_to_max: + # Single fused convolution + self.used_fuse_level = ConvSE3FuseLevel.FULL + + sum_freq = sum([ + degree_to_dim(min(d_in, d_out)) + for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max) + ]) + + self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0], + fuse_level=self.used_fuse_level, **common_args) + + elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ + unique_channels_in and fiber_in.degrees == degrees_up_to_max: + # Convolutions fused per output degree + self.used_fuse_level = ConvSE3FuseLevel.PARTIAL + self.conv_out = nn.ModuleDict() + for d_out, c_out in fiber_out: + sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees]) + self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out, + fuse_level=self.used_fuse_level, **common_args) + + elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ + unique_channels_out and fiber_out.degrees == degrees_up_to_max: + # Convolutions fused per input degree + self.used_fuse_level = ConvSE3FuseLevel.PARTIAL + self.conv_in = nn.ModuleDict() + for d_in, c_in in fiber_in: + sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees]) + channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0) + self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0], + fuse_level=self.used_fuse_level, **common_args) + else: + # Use pairwise TFN convolutions + self.used_fuse_level = ConvSE3FuseLevel.NONE + self.conv = nn.ModuleDict() + for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out): + dict_key = f'{degree_in},{degree_out}' + channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0) + sum_freq = degree_to_dim(min(degree_in, degree_out)) + self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out, + fuse_level=self.used_fuse_level, **common_args) + + if self_interaction: + self.to_kernel_self = nn.ParameterDict() + for degree_out, channels_out in fiber_out: + if fiber_in[degree_out]: + self.to_kernel_self[str(degree_out)] = nn.Parameter( + torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) + + def _try_unpad(self, feature, basis): + # Account for padded basis + if basis is not None: + out_dim = basis.shape[-1] + out_dim += out_dim % 2 - 1 + return feature[..., :out_dim] + else: + return feature + + def forward( + self, + node_feats: Dict[str, Tensor], + edge_feats: Dict[str, Tensor], + graph: DGLGraph, + basis: Dict[str, Tensor] + ): + with nvtx_range(f'ConvSE3'): + invariant_edge_feats = edge_feats['0'].squeeze(-1) + src, dst = graph.edges() + out = {} + in_features = [] + + # Fetch all input features from edge and node features + for degree_in in self.fiber_in.degrees: + src_node_features = node_feats[str(degree_in)][src] + if degree_in > 0 and str(degree_in) in edge_feats: + # Handle edge features of any type by concatenating them to node features + src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1) + in_features.append(src_node_features) + + if self.used_fuse_level == ConvSE3FuseLevel.FULL: + in_features_fused = torch.cat(in_features, dim=-1) + out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused']) + + if not self.allow_fused_output or self.self_interaction or self.pool: + out = unfuse_features(out, self.fiber_out.degrees) + + elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'): + in_features_fused = torch.cat(in_features, dim=-1) + for degree_out in self.fiber_out.degrees: + basis_used = basis[f'out{degree_out}_fused'] + out[str(degree_out)] = self._try_unpad( + self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used), + basis_used) + + elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'): + out = 0 + for degree_in, feature in zip(self.fiber_in.degrees, in_features): + out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused']) + if not self.allow_fused_output or self.self_interaction or self.pool: + out = unfuse_features(out, self.fiber_out.degrees) + else: + # Fallback to pairwise TFN convolutions + for degree_out in self.fiber_out.degrees: + out_feature = 0 + for degree_in, feature in zip(self.fiber_in.degrees, in_features): + dict_key = f'{degree_in},{degree_out}' + basis_used = basis.get(dict_key, None) + out_feature = out_feature + self._try_unpad( + self.conv[dict_key](feature, invariant_edge_feats, basis_used), + basis_used) + out[str(degree_out)] = out_feature + + for degree_out in self.fiber_out.degrees: + if self.self_interaction and str(degree_out) in self.to_kernel_self: + with nvtx_range(f'self interaction'): + dst_features = node_feats[str(degree_out)][dst] + kernel_self = self.to_kernel_self[str(degree_out)] + out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features + + if self.pool: + with nvtx_range(f'pooling'): + if isinstance(out, dict): + out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)]) + else: + out = dgl.ops.copy_e_sum(graph, out) + return out diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py index f720d77e..c138d897 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py @@ -1,59 +1,59 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from typing import Dict - -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor - -from se3_transformer.model.fiber import Fiber - - -class LinearSE3(nn.Module): - """ - Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. - Maps a fiber to a fiber with the same degrees (channels may be different). - No interaction between degrees, but interaction between channels. - - type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels) - type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels) - : - type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels) - """ - - def __init__(self, fiber_in: Fiber, fiber_out: Fiber): - super().__init__() - self.weights = nn.ParameterDict({ - str(degree_out): nn.Parameter( - torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) - for degree_out, channels_out in fiber_out - }) - - def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: - return { - degree: self.weights[degree] @ features[degree] - for degree, weight in self.weights.items() - } +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from se3_transformer.model.fiber import Fiber + + +class LinearSE3(nn.Module): + """ + Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. + Maps a fiber to a fiber with the same degrees (channels may be different). + No interaction between degrees, but interaction between channels. + + type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels) + type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels) + : + type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels) + """ + + def __init__(self, fiber_in: Fiber, fiber_out: Fiber): + super().__init__() + self.weights = nn.ParameterDict({ + str(degree_out): nn.Parameter( + torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) + for degree_out, channels_out in fiber_out + }) + + def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: + return { + degree: self.weights[degree] @ features[degree] + for degree, weight in self.weights.items() + } diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py index acbe23d7..71c753f6 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py @@ -1,83 +1,83 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from typing import Dict - -import torch -import torch.nn as nn -from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range - -from se3_transformer.model.fiber import Fiber - - -class NormSE3(nn.Module): - """ - Norm-based SE(3)-equivariant nonlinearity. - - ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐ - feature_in ──┤ * ──> feature_out - └──> feature_phase ────────────────────────────┘ - """ - - NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 - - def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): - super().__init__() - self.fiber = fiber - self.nonlinearity = nonlinearity - - if len(set(fiber.channels)) == 1: - # Fuse all the layer normalizations into a group normalization - self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) - else: - # Use multiple layer normalizations - self.layer_norms = nn.ModuleDict({ - str(degree): nn.LayerNorm(channels) - for degree, channels in fiber - }) - - def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: - with nvtx_range('NormSE3'): - output = {} - if hasattr(self, 'group_norm'): - # Compute per-degree norms of features - norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) - for d in self.fiber.degrees] - fused_norms = torch.cat(norms, dim=-2) - - # Transform the norms only - new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) - new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) - - # Scale features to the new norms - for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): - output[str(d)] = features[str(d)] / norm * new_norm - else: - for degree, feat in features.items(): - norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) - new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) - output[degree] = new_norm * feat / norm - - return output +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from typing import Dict + +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.nvtx import range as nvtx_range + +from se3_transformer.model.fiber import Fiber + + +class NormSE3(nn.Module): + """ + Norm-based SE(3)-equivariant nonlinearity. + + ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐ + feature_in ──┤ * ──> feature_out + └──> feature_phase ────────────────────────────┘ + """ + + NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 + + def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): + super().__init__() + self.fiber = fiber + self.nonlinearity = nonlinearity + + if len(set(fiber.channels)) == 1: + # Fuse all the layer normalizations into a group normalization + self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) + else: + # Use multiple layer normalizations + self.layer_norms = nn.ModuleDict({ + str(degree): nn.LayerNorm(channels) + for degree, channels in fiber + }) + + def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: + with nvtx_range('NormSE3'): + output = {} + if hasattr(self, 'group_norm'): + # Compute per-degree norms of features + norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) + for d in self.fiber.degrees] + fused_norms = torch.cat(norms, dim=-2) + + # Transform the norms only + new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) + new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) + + # Scale features to the new norms + for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): + output[str(d)] = features[str(d)] / norm * new_norm + else: + for degree, feat in features.items(): + norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) + new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) + output[degree] = new_norm * feat / norm + + return output diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py index e42c5383..273aab6c 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py @@ -1,53 +1,53 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -from typing import Dict, Literal - -import torch.nn as nn -from dgl import DGLGraph -from dgl.nn.pytorch import AvgPooling, MaxPooling -from torch import Tensor - - -class GPooling(nn.Module): - """ - Graph max/average pooling on a given feature type. - The average can be taken for any feature type, and equivariance will be maintained. - The maximum can only be taken for invariant features (type 0). - If you want max-pooling for type > 0 features, look into Vector Neurons. - """ - - def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): - """ - :param feat_type: Feature type to pool - :param pool: Type of pooling: max or avg - """ - super().__init__() - assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' - assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' - self.feat_type = feat_type - self.pool = MaxPooling() if pool == 'max' else AvgPooling() - - def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: - pooled = self.pool(graph, features[str(self.feat_type)]) - return pooled.squeeze(dim=-1) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +from typing import Dict, Literal + +import torch.nn as nn +from dgl import DGLGraph +from dgl.nn.pytorch import AvgPooling, MaxPooling +from torch import Tensor + + +class GPooling(nn.Module): + """ + Graph max/average pooling on a given feature type. + The average can be taken for any feature type, and equivariance will be maintained. + The maximum can only be taken for invariant features (type 0). + If you want max-pooling for type > 0 features, look into Vector Neurons. + """ + + def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): + """ + :param feat_type: Feature type to pool + :param pool: Type of pooling: max or avg + """ + super().__init__() + assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' + assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' + self.feat_type = feat_type + self.pool = MaxPooling() if pool == 'max' else AvgPooling() + + def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: + pooled = self.pool(graph, features[str(self.feat_type)]) + return pooled.squeeze(dim=-1) diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py index ab89f77a..bec7cc8e 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py @@ -1,222 +1,222 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import logging -from typing import Optional, Literal, Dict - -import torch -import torch.nn as nn -from dgl import DGLGraph -from torch import Tensor - -from se3_transformer.model.basis import get_basis, update_basis_with_fused -from se3_transformer.model.layers.attention import AttentionBlockSE3 -from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel -from se3_transformer.model.layers.norm import NormSE3 -from se3_transformer.model.layers.pooling import GPooling -from se3_transformer.runtime.utils import str2bool -from se3_transformer.model.fiber import Fiber - - -class Sequential(nn.Sequential): - """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """ - - def forward(self, input, *args, **kwargs): - for module in self: - input = module(input, *args, **kwargs) - return input - - -def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None): - """ Add relative positions to existing edge features """ - edge_features = edge_features.copy() if edge_features else {} - r = relative_pos.norm(dim=-1, keepdim=True) - if '0' in edge_features: - edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1) - else: - edge_features['0'] = r[..., None] - - return edge_features - - -class SE3Transformer(nn.Module): - def __init__(self, - num_layers: int, - fiber_in: Fiber, - fiber_hidden: Fiber, - fiber_out: Fiber, - num_heads: int, - channels_div: int, - fiber_edge: Fiber = Fiber({}), - return_type: Optional[int] = None, - pooling: Optional[Literal['avg', 'max']] = None, - norm: bool = True, - use_layer_norm: bool = True, - tensor_cores: bool = False, - low_memory: bool = False, - **kwargs): - """ - :param num_layers: Number of attention layers - :param fiber_in: Input fiber description - :param fiber_hidden: Hidden fiber description - :param fiber_out: Output fiber description - :param fiber_edge: Input edge fiber description - :param num_heads: Number of attention heads - :param channels_div: Channels division before feeding to attention layer - :param return_type: Return only features of this type - :param pooling: 'avg' or 'max' graph pooling before MLP layers - :param norm: Apply a normalization layer after each attention block - :param use_layer_norm: Apply layer normalization between MLP layers - :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases) - :param low_memory: If True, will use slower ops that use less memory - """ - super().__init__() - self.num_layers = num_layers - self.fiber_edge = fiber_edge - self.num_heads = num_heads - self.channels_div = channels_div - self.return_type = return_type - self.pooling = pooling - self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees) - self.tensor_cores = tensor_cores - self.low_memory = low_memory - - if low_memory and not tensor_cores: - logging.warning('Low memory mode will have no effect with no Tensor Cores') - - # Fully fused convolutions when using Tensor Cores (and not low memory mode) - fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL - - graph_modules = [] - for i in range(num_layers): - graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in, - fiber_out=fiber_hidden, - fiber_edge=fiber_edge, - num_heads=num_heads, - channels_div=channels_div, - use_layer_norm=use_layer_norm, - max_degree=self.max_degree, - fuse_level=fuse_level)) - if norm: - graph_modules.append(NormSE3(fiber_hidden)) - fiber_in = fiber_hidden - - graph_modules.append(ConvSE3(fiber_in=fiber_in, - fiber_out=fiber_out, - fiber_edge=fiber_edge, - self_interaction=True, - use_layer_norm=use_layer_norm, - max_degree=self.max_degree)) - self.graph_modules = Sequential(*graph_modules) - - if pooling is not None: - assert return_type is not None, 'return_type must be specified when pooling' - self.pooling_module = GPooling(pool=pooling, feat_type=return_type) - - def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor], - edge_feats: Optional[Dict[str, Tensor]] = None, - basis: Optional[Dict[str, Tensor]] = None): - # Compute bases in case they weren't precomputed as part of the data loading - basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False, - use_pad_trick=self.tensor_cores and not self.low_memory, - amp=torch.is_autocast_enabled()) - - # Add fused bases (per output degree, per input degree, and fully fused) to the dict - basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory, - fully_fused=self.tensor_cores and not self.low_memory) - - edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats) - - node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis) - - if self.pooling is not None: - return self.pooling_module(node_feats, graph=graph) - - if self.return_type is not None: - return node_feats[str(self.return_type)] - - return node_feats - - @staticmethod - def add_argparse_args(parser): - parser.add_argument('--num_layers', type=int, default=7, - help='Number of stacked Transformer layers') - parser.add_argument('--num_heads', type=int, default=8, - help='Number of heads in self-attention') - parser.add_argument('--channels_div', type=int, default=2, - help='Channels division before feeding to attention layer') - parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'], - help='Type of graph pooling') - parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False, - help='Apply a normalization layer after each attention block') - parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False, - help='Apply layer normalization between MLP layers') - parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False, - help='If true, will use fused ops that are slower but that use less memory ' - '(expect 25 percent less memory). ' - 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs') - - return parser - - -class SE3TransformerPooled(nn.Module): - def __init__(self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Fiber, - num_degrees: int, - num_channels: int, - output_dim: int, - **kwargs): - super().__init__() - kwargs['pooling'] = kwargs['pooling'] or 'max' - self.transformer = SE3Transformer( - fiber_in=fiber_in, - fiber_hidden=Fiber.create(num_degrees, num_channels), - fiber_out=fiber_out, - fiber_edge=fiber_edge, - return_type=0, - **kwargs - ) - - n_out_features = fiber_out.num_features - self.mlp = nn.Sequential( - nn.Linear(n_out_features, n_out_features), - nn.ReLU(), - nn.Linear(n_out_features, output_dim) - ) - - def forward(self, graph, node_feats, edge_feats, basis=None): - feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1) - y = self.mlp(feats).squeeze(-1) - return y - - @staticmethod - def add_argparse_args(parent_parser): - parser = parent_parser.add_argument_group("Model architecture") - SE3Transformer.add_argparse_args(parser) - parser.add_argument('--num_degrees', - help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]', - type=int, default=4) - parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32) - return parent_parser +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import logging +from typing import Optional, Literal, Dict + +import torch +import torch.nn as nn +from dgl import DGLGraph +from torch import Tensor + +from se3_transformer.model.basis import get_basis, update_basis_with_fused +from se3_transformer.model.layers.attention import AttentionBlockSE3 +from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel +from se3_transformer.model.layers.norm import NormSE3 +from se3_transformer.model.layers.pooling import GPooling +from se3_transformer.runtime.utils import str2bool +from se3_transformer.model.fiber import Fiber + + +class Sequential(nn.Sequential): + """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """ + + def forward(self, input, *args, **kwargs): + for module in self: + input = module(input, *args, **kwargs) + return input + + +def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None): + """ Add relative positions to existing edge features """ + edge_features = edge_features.copy() if edge_features else {} + r = relative_pos.norm(dim=-1, keepdim=True) + if '0' in edge_features: + edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1) + else: + edge_features['0'] = r[..., None] + + return edge_features + + +class SE3Transformer(nn.Module): + def __init__(self, + num_layers: int, + fiber_in: Fiber, + fiber_hidden: Fiber, + fiber_out: Fiber, + num_heads: int, + channels_div: int, + fiber_edge: Fiber = Fiber({}), + return_type: Optional[int] = None, + pooling: Optional[Literal['avg', 'max']] = None, + norm: bool = True, + use_layer_norm: bool = True, + tensor_cores: bool = False, + low_memory: bool = False, + **kwargs): + """ + :param num_layers: Number of attention layers + :param fiber_in: Input fiber description + :param fiber_hidden: Hidden fiber description + :param fiber_out: Output fiber description + :param fiber_edge: Input edge fiber description + :param num_heads: Number of attention heads + :param channels_div: Channels division before feeding to attention layer + :param return_type: Return only features of this type + :param pooling: 'avg' or 'max' graph pooling before MLP layers + :param norm: Apply a normalization layer after each attention block + :param use_layer_norm: Apply layer normalization between MLP layers + :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases) + :param low_memory: If True, will use slower ops that use less memory + """ + super().__init__() + self.num_layers = num_layers + self.fiber_edge = fiber_edge + self.num_heads = num_heads + self.channels_div = channels_div + self.return_type = return_type + self.pooling = pooling + self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees) + self.tensor_cores = tensor_cores + self.low_memory = low_memory + + if low_memory and not tensor_cores: + logging.warning('Low memory mode will have no effect with no Tensor Cores') + + # Fully fused convolutions when using Tensor Cores (and not low memory mode) + fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL + + graph_modules = [] + for i in range(num_layers): + graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in, + fiber_out=fiber_hidden, + fiber_edge=fiber_edge, + num_heads=num_heads, + channels_div=channels_div, + use_layer_norm=use_layer_norm, + max_degree=self.max_degree, + fuse_level=fuse_level)) + if norm: + graph_modules.append(NormSE3(fiber_hidden)) + fiber_in = fiber_hidden + + graph_modules.append(ConvSE3(fiber_in=fiber_in, + fiber_out=fiber_out, + fiber_edge=fiber_edge, + self_interaction=True, + use_layer_norm=use_layer_norm, + max_degree=self.max_degree)) + self.graph_modules = Sequential(*graph_modules) + + if pooling is not None: + assert return_type is not None, 'return_type must be specified when pooling' + self.pooling_module = GPooling(pool=pooling, feat_type=return_type) + + def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor], + edge_feats: Optional[Dict[str, Tensor]] = None, + basis: Optional[Dict[str, Tensor]] = None): + # Compute bases in case they weren't precomputed as part of the data loading + basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False, + use_pad_trick=self.tensor_cores and not self.low_memory, + amp=torch.is_autocast_enabled()) + + # Add fused bases (per output degree, per input degree, and fully fused) to the dict + basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory, + fully_fused=self.tensor_cores and not self.low_memory) + + edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats) + + node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis) + + if self.pooling is not None: + return self.pooling_module(node_feats, graph=graph) + + if self.return_type is not None: + return node_feats[str(self.return_type)] + + return node_feats + + @staticmethod + def add_argparse_args(parser): + parser.add_argument('--num_layers', type=int, default=7, + help='Number of stacked Transformer layers') + parser.add_argument('--num_heads', type=int, default=8, + help='Number of heads in self-attention') + parser.add_argument('--channels_div', type=int, default=2, + help='Channels division before feeding to attention layer') + parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'], + help='Type of graph pooling') + parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False, + help='Apply a normalization layer after each attention block') + parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False, + help='Apply layer normalization between MLP layers') + parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False, + help='If true, will use fused ops that are slower but that use less memory ' + '(expect 25 percent less memory). ' + 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs') + + return parser + + +class SE3TransformerPooled(nn.Module): + def __init__(self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Fiber, + num_degrees: int, + num_channels: int, + output_dim: int, + **kwargs): + super().__init__() + kwargs['pooling'] = kwargs['pooling'] or 'max' + self.transformer = SE3Transformer( + fiber_in=fiber_in, + fiber_hidden=Fiber.create(num_degrees, num_channels), + fiber_out=fiber_out, + fiber_edge=fiber_edge, + return_type=0, + **kwargs + ) + + n_out_features = fiber_out.num_features + self.mlp = nn.Sequential( + nn.Linear(n_out_features, n_out_features), + nn.ReLU(), + nn.Linear(n_out_features, output_dim) + ) + + def forward(self, graph, node_feats, edge_feats, basis=None): + feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1) + y = self.mlp(feats).squeeze(-1) + return y + + @staticmethod + def add_argparse_args(parent_parser): + parser = parent_parser.add_argument_group("Model architecture") + SE3Transformer.add_argparse_args(parser) + parser.add_argument('--num_degrees', + help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]', + type=int, default=4) + parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32) + return parent_parser diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py index d35d5ee1..7730ddcf 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py @@ -1,70 +1,70 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import argparse -import pathlib - -from se3_transformer.data_loading import QM9DataModule -from se3_transformer.model import SE3TransformerPooled -from se3_transformer.runtime.utils import str2bool - -PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') - -paths = PARSER.add_argument_group('Paths') -paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), - help='Directory where the data is located or should be downloaded') -paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), - help='Directory where the results logs should be saved') -paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', - help='Name for the resulting DLLogger JSON file') -paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, - help='File where the checkpoint should be saved') -paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, - help='File of the checkpoint to be loaded') - -optimizer = PARSER.add_argument_group('Optimizer') -optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') -optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) -optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) -optimizer.add_argument('--momentum', type=float, default=0.9) -optimizer.add_argument('--weight_decay', type=float, default=0.1) - -PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') -PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') -PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') -PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') - -PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') -PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') -PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') -PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') -PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1, - help='Do an evaluation round every N epochs') -PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, - help='Minimize stdout output') - -PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, - help='Benchmark mode') - -QM9DataModule.add_argparse_args(PARSER) -SE3TransformerPooled.add_argparse_args(PARSER) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import argparse +import pathlib + +from se3_transformer.data_loading import QM9DataModule +from se3_transformer.model import SE3TransformerPooled +from se3_transformer.runtime.utils import str2bool + +PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') + +paths = PARSER.add_argument_group('Paths') +paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), + help='Directory where the data is located or should be downloaded') +paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), + help='Directory where the results logs should be saved') +paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', + help='Name for the resulting DLLogger JSON file') +paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, + help='File where the checkpoint should be saved') +paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, + help='File of the checkpoint to be loaded') + +optimizer = PARSER.add_argument_group('Optimizer') +optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') +optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) +optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) +optimizer.add_argument('--momentum', type=float, default=0.9) +optimizer.add_argument('--weight_decay', type=float, default=0.1) + +PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') +PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') +PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') +PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') + +PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') +PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') +PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') +PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') +PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20, + help='Do an evaluation round every N epochs') +PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, + help='Minimize stdout output') + +PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, + help='Benchmark mode') + +QM9DataModule.add_argparse_args(PARSER) +SE3TransformerPooled.add_argparse_args(PARSER) diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py index 4e41fbe8..ebb4f501 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py @@ -1,240 +1,240 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import logging -import pathlib -from typing import List - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from apex.optimizers import FusedAdam, FusedLAMB -from torch.nn.modules.loss import _Loss -from torch.nn.parallel import DistributedDataParallel -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler -from tqdm import tqdm - -from se3_transformer.data_loading import QM9DataModule -from se3_transformer.model import SE3TransformerPooled -from se3_transformer.model.fiber import Fiber -from se3_transformer.runtime import gpu_affinity -from se3_transformer.runtime.arguments import PARSER -from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \ - PerformanceCallback -from se3_transformer.runtime.inference import evaluate -from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger -from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \ - using_tensor_cores, increase_l2_fetch_granularity - - -def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]): - """ Saves model, optimizer and epoch states to path (only once per node) """ - if get_local_rank() == 0: - state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() - checkpoint = { - 'state_dict': state_dict, - 'optimizer_state_dict': optimizer.state_dict(), - 'epoch': epoch - } - for callback in callbacks: - callback.on_checkpoint_save(checkpoint) - - torch.save(checkpoint, str(path)) - logging.info(f'Saved checkpoint to {str(path)}') - - -def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]): - """ Loads model, optimizer and epoch states from path """ - checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'}) - if isinstance(model, DistributedDataParallel): - model.module.load_state_dict(checkpoint['state_dict']) - else: - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - for callback in callbacks: - callback.on_checkpoint_load(checkpoint) - - logging.info(f'Loaded checkpoint from {str(path)}') - return checkpoint['epoch'] - - -def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args): - losses = [] - for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch', - desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)): - *inputs, target = to_cuda(batch) - - for callback in callbacks: - callback.on_batch_start() - - with torch.cuda.amp.autocast(enabled=args.amp): - pred = model(*inputs) - loss = loss_fn(pred, target) / args.accumulate_grad_batches - - grad_scaler.scale(loss).backward() - - # gradient accumulation - if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader): - if args.gradient_clip: - grad_scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip) - - grad_scaler.step(optimizer) - grad_scaler.update() - model.zero_grad(set_to_none=True) - - losses.append(loss.item()) - - return np.mean(losses) - - -def train(model: nn.Module, - loss_fn: _Loss, - train_dataloader: DataLoader, - val_dataloader: DataLoader, - callbacks: List[BaseCallback], - logger: Logger, - args): - device = torch.cuda.current_device() - model.to(device=device) - local_rank = get_local_rank() - world_size = dist.get_world_size() if dist.is_initialized() else 1 - - if dist.is_initialized(): - model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) - - model.train() - grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) - if args.optimizer == 'adam': - optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), - weight_decay=args.weight_decay) - elif args.optimizer == 'lamb': - optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), - weight_decay=args.weight_decay) - else: - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, - weight_decay=args.weight_decay) - - epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0 - - for callback in callbacks: - callback.on_fit_start(optimizer, args) - - for epoch_idx in range(epoch_start, args.epochs): - if isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch_idx) - - loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args) - if dist.is_initialized(): - loss = torch.tensor(loss, dtype=torch.float, device=device) - torch.distributed.all_reduce(loss) - loss = (loss / world_size).item() - - logging.info(f'Train loss: {loss}') - logger.log_metrics({'train loss': loss}, epoch_idx) - - for callback in callbacks: - callback.on_epoch_end() - - if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \ - and (epoch_idx + 1) % args.ckpt_interval == 0: - save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks) - - if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs): - evaluate(model, val_dataloader, callbacks, args) - model.train() - - for callback in callbacks: - callback.on_validation_end(epoch_idx) - - if args.save_ckpt_path is not None and not args.benchmark: - save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks) - - for callback in callbacks: - callback.on_fit_end() - - -def print_parameters_count(model): - num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) - logging.info(f'Number of trainable parameters: {num_params_trainable}') - - -if __name__ == '__main__': - is_distributed = init_distributed() - local_rank = get_local_rank() - args = PARSER.parse_args() - - logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO) - - logging.info('====== SE(3)-Transformer ======') - logging.info('| Training procedure |') - logging.info('===============================') - - if args.seed is not None: - logging.info(f'Using seed {args.seed}') - seed_everything(args.seed) - - logger = LoggerCollection([ - DLLogger(save_dir=args.log_dir, filename=args.dllogger_name), - WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer') - ]) - - datamodule = QM9DataModule(**vars(args)) - model = SE3TransformerPooled( - fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), - fiber_out=Fiber({0: args.num_degrees * args.num_channels}), - fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}), - output_dim=1, - tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively - **vars(args) - ) - loss_fn = nn.L1Loss() - - if args.benchmark: - logging.info('Running benchmark mode') - world_size = dist.get_world_size() if dist.is_initialized() else 1 - callbacks = [PerformanceCallback(logger, args.batch_size * world_size)] - else: - callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'), - QM9LRSchedulerCallback(logger, epochs=args.epochs)] - - if is_distributed: - gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count()) - - print_parameters_count(model) - logger.log_hyperparams(vars(args)) - increase_l2_fetch_granularity() - train(model, - loss_fn, - datamodule.train_dataloader(), - datamodule.val_dataloader(), - callbacks, - logger, - args) - - logging.info('Training finished successfully') - - +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import logging +import pathlib +from typing import List + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from apex.optimizers import FusedAdam, FusedLAMB +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm + +from se3_transformer.data_loading import QM9DataModule +from se3_transformer.model import SE3TransformerPooled +from se3_transformer.model.fiber import Fiber +from se3_transformer.runtime import gpu_affinity +from se3_transformer.runtime.arguments import PARSER +from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \ + PerformanceCallback +from se3_transformer.runtime.inference import evaluate +from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger +from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \ + using_tensor_cores, increase_l2_fetch_granularity + + +def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]): + """ Saves model, optimizer and epoch states to path (only once per node) """ + if get_local_rank() == 0: + state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() + checkpoint = { + 'state_dict': state_dict, + 'optimizer_state_dict': optimizer.state_dict(), + 'epoch': epoch + } + for callback in callbacks: + callback.on_checkpoint_save(checkpoint) + + torch.save(checkpoint, str(path)) + logging.info(f'Saved checkpoint to {str(path)}') + + +def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]): + """ Loads model, optimizer and epoch states from path """ + checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'}) + if isinstance(model, DistributedDataParallel): + model.module.load_state_dict(checkpoint['state_dict']) + else: + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + for callback in callbacks: + callback.on_checkpoint_load(checkpoint) + + logging.info(f'Loaded checkpoint from {str(path)}') + return checkpoint['epoch'] + + +def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args): + losses = [] + for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch', + desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)): + *inputs, target = to_cuda(batch) + + for callback in callbacks: + callback.on_batch_start() + + with torch.cuda.amp.autocast(enabled=args.amp): + pred = model(*inputs) + loss = loss_fn(pred, target) / args.accumulate_grad_batches + + grad_scaler.scale(loss).backward() + + # gradient accumulation + if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader): + if args.gradient_clip: + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip) + + grad_scaler.step(optimizer) + grad_scaler.update() + model.zero_grad(set_to_none=True) + + losses.append(loss.item()) + + return np.mean(losses) + + +def train(model: nn.Module, + loss_fn: _Loss, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + callbacks: List[BaseCallback], + logger: Logger, + args): + device = torch.cuda.current_device() + model.to(device=device) + local_rank = get_local_rank() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if dist.is_initialized(): + model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) + + model.train() + grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) + if args.optimizer == 'adam': + optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), + weight_decay=args.weight_decay) + elif args.optimizer == 'lamb': + optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), + weight_decay=args.weight_decay) + else: + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, + weight_decay=args.weight_decay) + + epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0 + + for callback in callbacks: + callback.on_fit_start(optimizer, args) + + for epoch_idx in range(epoch_start, args.epochs): + if isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch_idx) + + loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args) + if dist.is_initialized(): + loss = torch.tensor(loss, dtype=torch.float, device=device) + torch.distributed.all_reduce(loss) + loss = (loss / world_size).item() + + logging.info(f'Train loss: {loss}') + logger.log_metrics({'train loss': loss}, epoch_idx) + + for callback in callbacks: + callback.on_epoch_end() + + if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \ + and (epoch_idx + 1) % args.ckpt_interval == 0: + save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks) + + if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs): + evaluate(model, val_dataloader, callbacks, args) + model.train() + + for callback in callbacks: + callback.on_validation_end(epoch_idx) + + if args.save_ckpt_path is not None and not args.benchmark: + save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks) + + for callback in callbacks: + callback.on_fit_end() + + +def print_parameters_count(model): + num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + logging.info(f'Number of trainable parameters: {num_params_trainable}') + + +if __name__ == '__main__': + is_distributed = init_distributed() + local_rank = get_local_rank() + args = PARSER.parse_args() + + logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO) + + logging.info('====== SE(3)-Transformer ======') + logging.info('| Training procedure |') + logging.info('===============================') + + if args.seed is not None: + logging.info(f'Using seed {args.seed}') + seed_everything(args.seed) + + logger = LoggerCollection([ + DLLogger(save_dir=args.log_dir, filename=args.dllogger_name), + WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer') + ]) + + datamodule = QM9DataModule(**vars(args)) + model = SE3TransformerPooled( + fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), + fiber_out=Fiber({0: args.num_degrees * args.num_channels}), + fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}), + output_dim=1, + tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively + **vars(args) + ) + loss_fn = nn.L1Loss() + + if args.benchmark: + logging.info('Running benchmark mode') + world_size = dist.get_world_size() if dist.is_initialized() else 1 + callbacks = [PerformanceCallback(logger, args.batch_size * world_size)] + else: + callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'), + QM9LRSchedulerCallback(logger, epochs=args.epochs)] + + if is_distributed: + gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count()) + + print_parameters_count(model) + logger.log_hyperparams(vars(args)) + increase_l2_fetch_granularity() + train(model, + loss_fn, + datamodule.train_dataloader(), + datamodule.val_dataloader(), + callbacks, + logger, + args) + + logging.info('Training finished successfully') + + diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py index a0a29b7f..f19e4d94 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py @@ -1,102 +1,102 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import torch - -from se3_transformer.model import SE3Transformer -from se3_transformer.model.fiber import Fiber -from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot - -# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) -TOL = 1e-3 -CHANNELS, NODES = 32, 512 - - -def _get_outputs(model, R): - feats0 = torch.randn(NODES, CHANNELS, 1) - feats1 = torch.randn(NODES, CHANNELS, 3) - - coords = torch.randn(NODES, 3) - graph = get_random_graph(NODES) - if torch.cuda.is_available(): - feats0 = feats0.cuda() - feats1 = feats1.cuda() - R = R.cuda() - coords = coords.cuda() - graph = graph.to('cuda') - model.cuda() - - graph1 = assign_relative_pos(graph, coords) - out1 = model(graph1, {'0': feats0, '1': feats1}, {}) - graph2 = assign_relative_pos(graph, coords @ R) - out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {}) - - return out1, out2 - - -def _get_model(**kwargs): - return SE3Transformer( - num_layers=4, - fiber_in=Fiber.create(2, CHANNELS), - fiber_hidden=Fiber.create(3, CHANNELS), - fiber_out=Fiber.create(2, CHANNELS), - fiber_edge=Fiber({}), - num_heads=8, - channels_div=2, - **kwargs - ) - - -def test_equivariance(): - model = _get_model() - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ - f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' - assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ - f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}' - - -def test_equivariance_pooled(): - model = _get_model(pooling='avg', return_type=1) - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2, (out1 @ R), atol=TOL), \ - f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' - - -def test_invariance_pooled(): - model = _get_model(pooling='avg', return_type=0) - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2, out1, atol=TOL), \ - f'type-0 features should be invariant {get_max_diff(out1, out2)}' +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import torch + +from se3_transformer.model import SE3Transformer +from se3_transformer.model.fiber import Fiber +from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot + +# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) +TOL = 1e-3 +CHANNELS, NODES = 32, 512 + + +def _get_outputs(model, R): + feats0 = torch.randn(NODES, CHANNELS, 1) + feats1 = torch.randn(NODES, CHANNELS, 3) + + coords = torch.randn(NODES, 3) + graph = get_random_graph(NODES) + if torch.cuda.is_available(): + feats0 = feats0.cuda() + feats1 = feats1.cuda() + R = R.cuda() + coords = coords.cuda() + graph = graph.to('cuda') + model.cuda() + + graph1 = assign_relative_pos(graph, coords) + out1 = model(graph1, {'0': feats0, '1': feats1}, {}) + graph2 = assign_relative_pos(graph, coords @ R) + out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {}) + + return out1, out2 + + +def _get_model(**kwargs): + return SE3Transformer( + num_layers=4, + fiber_in=Fiber.create(2, CHANNELS), + fiber_hidden=Fiber.create(3, CHANNELS), + fiber_out=Fiber.create(2, CHANNELS), + fiber_edge=Fiber({}), + num_heads=8, + channels_div=2, + **kwargs + ) + + +def test_equivariance(): + model = _get_model() + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ + f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' + assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ + f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}' + + +def test_equivariance_pooled(): + model = _get_model(pooling='avg', return_type=1) + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2, (out1 @ R), atol=TOL), \ + f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' + + +def test_invariance_pooled(): + model = _get_model(pooling='avg', return_type=0) + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2, out1, atol=TOL), \ + f'type-0 features should be invariant {get_max_diff(out1, out2)}' diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py index 195f0aef..d72bebc3 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py @@ -1,60 +1,60 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import dgl -import torch - - -def get_random_graph(N, num_edges_factor=18): - graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) - return graph - - -def assign_relative_pos(graph, coords): - src, dst = graph.edges() - graph.edata['rel_pos'] = coords[src] - coords[dst] - return graph - - -def get_max_diff(a, b): - return (a - b).abs().max().item() - - -def rot_z(gamma): - return torch.tensor([ - [torch.cos(gamma), -torch.sin(gamma), 0], - [torch.sin(gamma), torch.cos(gamma), 0], - [0, 0, 1] - ], dtype=gamma.dtype) - - -def rot_y(beta): - return torch.tensor([ - [torch.cos(beta), 0, torch.sin(beta)], - [0, 1, 0], - [-torch.sin(beta), 0, torch.cos(beta)] - ], dtype=beta.dtype) - - -def rot(alpha, beta, gamma): - return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import dgl +import torch + + +def get_random_graph(N, num_edges_factor=18): + graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) + return graph + + +def assign_relative_pos(graph, coords): + src, dst = graph.edges() + graph.edata['rel_pos'] = coords[src] - coords[dst] + return graph + + +def get_max_diff(a, b): + return (a - b).abs().max().item() + + +def rot_z(gamma): + return torch.tensor([ + [torch.cos(gamma), -torch.sin(gamma), 0], + [torch.sin(gamma), torch.cos(gamma), 0], + [0, 0, 1] + ], dtype=gamma.dtype) + + +def rot_y(beta): + return torch.tensor([ + [torch.cos(beta), 0, torch.sin(beta)], + [0, 1, 0], + [-torch.sin(beta), 0, torch.cos(beta)] + ], dtype=beta.dtype) + + +def rot(alpha, beta, gamma): + return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)