From 22d6621dcdabf486c8233651bd70523fa8f2963d Mon Sep 17 00:00:00 2001 From: Alexandre Milesi Date: Mon, 1 Nov 2021 17:49:17 +0100 Subject: [PATCH] [SE3Transformer/DGLPyT] Better low memory mode --- .../DrugDiscovery/SE3Transformer/Dockerfile | 3 +- .../DrugDiscovery/SE3Transformer/README.md | 18 +- .../data_loading/data_module.py | 3 +- .../se3_transformer/model/__init__.py | 4 +- .../se3_transformer/model/fiber.py | 288 ++++---- .../se3_transformer/model/layers/__init__.py | 8 +- .../se3_transformer/model/layers/attention.py | 361 ++++----- .../model/layers/convolution.py | 699 +++++++++--------- .../se3_transformer/model/layers/linear.py | 118 +-- .../se3_transformer/model/layers/norm.py | 166 ++--- .../se3_transformer/model/layers/pooling.py | 106 +-- .../se3_transformer/model/transformer.py | 445 +++++------ .../se3_transformer/runtime/arguments.py | 142 ++-- .../se3_transformer/runtime/inference.py | 8 +- .../se3_transformer/runtime/training.py | 481 ++++++------ .../DrugDiscovery/SE3Transformer/setup.py | 4 +- .../SE3Transformer/tests/test_equivariance.py | 208 +++--- .../SE3Transformer/tests/utils.py | 120 +-- README.md | 2 +- 19 files changed, 1610 insertions(+), 1574 deletions(-) diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile b/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile index b4e4eec2..ab841025 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile @@ -42,7 +42,6 @@ RUN make -j8 FROM ${FROM_IMAGE_NAME} -RUN rm -rf /workspace/* WORKDIR /workspace/se3-transformer # copy built DGL and install it @@ -55,3 +54,5 @@ ADD . . ENV DGLBACKEND=pytorch ENV OMP_NUM_THREADS=1 + + diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md b/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md index f856d1ad..ab716a31 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/README.md @@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model: - The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers - When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions) - Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met -- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`) +- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode: + + | | FP32 | AMP + |---|-----------------------|-------------------------- + |`--low_memory false` (default) | 4.7x | 7.1x + |`--low_memory true` | 29.4x | 43.6x + **Self-attention optimizations** @@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script conta - `--pooling`: Type of graph pooling (default: `max`) - `--norm`: Apply a normalization layer after each attention block (default: `false`) - `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`) -- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`) +- `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`) - `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`) - `--num_channels`: Number of channels for the hidden features (default: `32`) @@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a mod By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`. -You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable. +You can connect your existing Weights & Biases account by setting the WANDB_API_KEY environment variable, and enabling the `--wandb` flag. +If no API key is set, `--wandb` will log the run anonymously to Weights & Biases. **Checkpoints** @@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic ### Changelog +November 2021: +- Improved low memory mode to give further 6x memory savings +- Disabled W&B logging by default +- Fixed persistent workers when using one data loading process + October 2021: - Updated README performance tables - Fixed shape mismatch when using partially fused TFNs per output degree diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/data_module.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/data_module.py index 1047d41d..0cfb5613 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/data_module.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/data_module.py @@ -46,7 +46,8 @@ class DataModule(ABC): if dist.is_initialized(): dist.barrier(device_ids=[get_local_rank()]) - self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs} + self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0, + **dataloader_kwargs} self.ds_train, self.ds_val, self.ds_test = None, None, None def prepare_data(self): diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py index 6d796b2d..628d01e9 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/__init__.py @@ -1,2 +1,2 @@ -from .transformer import SE3Transformer, SE3TransformerPooled -from .fiber import Fiber +from .transformer import SE3Transformer, SE3TransformerPooled +from .fiber import Fiber diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py index 50d5c6da..38db33b0 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/fiber.py @@ -1,144 +1,144 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from collections import namedtuple -from itertools import product -from typing import Dict - -import torch -from torch import Tensor - -from se3_transformer.runtime.utils import degree_to_dim - -FiberEl = namedtuple('FiberEl', ['degree', 'channels']) - - -class Fiber(dict): - """ - Describes the structure of some set of features. - Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. - Type-0 features: invariant scalars - Type-1 features: equivariant 3D vectors - Type-2 features: equivariant symmetric traceless matrices - ... - - As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. - The 'multiplicity' or 'number of channels' is the number of features of a given type. - This class puts together all the degrees and their multiplicities in order to describe - the inputs, outputs or hidden features of SE3 layers. - """ - - def __init__(self, structure): - if isinstance(structure, dict): - structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] - elif not isinstance(structure[0], FiberEl): - structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) - self.structure = structure - super().__init__({d: m for d, m in self.structure}) - - @property - def degrees(self): - return sorted([t.degree for t in self.structure]) - - @property - def channels(self): - return [self[d] for d in self.degrees] - - @property - def num_features(self): - """ Size of the resulting tensor if all features were concatenated together """ - return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) - - @staticmethod - def create(num_degrees: int, num_channels: int): - """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ - return Fiber([(degree, num_channels) for degree in range(num_degrees)]) - - @staticmethod - def from_features(feats: Dict[str, Tensor]): - """ Infer the Fiber structure from a feature dict """ - structure = {} - for k, v in feats.items(): - degree = int(k) - assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' - assert v.shape[-1] == degree_to_dim(degree) - structure[degree] = v.shape[-2] - return Fiber(structure) - - def __getitem__(self, degree: int): - """ fiber[degree] returns the multiplicity for this degree """ - return dict(self.structure).get(degree, 0) - - def __iter__(self): - """ Iterate over namedtuples (degree, channels) """ - return iter(self.structure) - - def __mul__(self, other): - """ - If other in an int, multiplies all the multiplicities by other. - If other is a fiber, returns the cartesian product. - """ - if isinstance(other, Fiber): - return product(self.structure, other.structure) - elif isinstance(other, int): - return Fiber({t.degree: t.channels * other for t in self.structure}) - - def __add__(self, other): - """ - If other in an int, add other to all the multiplicities. - If other is a fiber, add the multiplicities of the fibers together. - """ - if isinstance(other, Fiber): - return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) - elif isinstance(other, int): - return Fiber({t.degree: t.channels + other for t in self.structure}) - - def __repr__(self): - return str(self.structure) - - @staticmethod - def combine_max(f1, f2): - """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ - new_dict = dict(f1.structure) - for k, m in f2.structure: - new_dict[k] = max(new_dict.get(k, 0), m) - - return Fiber(list(new_dict.items())) - - @staticmethod - def combine_selectively(f1, f2): - """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ - # only use orders which occur in fiber f1 - new_dict = dict(f1.structure) - for k in f1.degrees: - if k in f2.degrees: - new_dict[k] += f2[k] - return Fiber(list(new_dict.items())) - - def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): - # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) - fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in - self.degrees] - fibers = torch.cat(fibers, -1) - return fibers +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from collections import namedtuple +from itertools import product +from typing import Dict + +import torch +from torch import Tensor + +from se3_transformer.runtime.utils import degree_to_dim + +FiberEl = namedtuple('FiberEl', ['degree', 'channels']) + + +class Fiber(dict): + """ + Describes the structure of some set of features. + Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. + Type-0 features: invariant scalars + Type-1 features: equivariant 3D vectors + Type-2 features: equivariant symmetric traceless matrices + ... + + As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. + The 'multiplicity' or 'number of channels' is the number of features of a given type. + This class puts together all the degrees and their multiplicities in order to describe + the inputs, outputs or hidden features of SE3 layers. + """ + + def __init__(self, structure): + if isinstance(structure, dict): + structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] + elif not isinstance(structure[0], FiberEl): + structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) + self.structure = structure + super().__init__({d: m for d, m in self.structure}) + + @property + def degrees(self): + return sorted([t.degree for t in self.structure]) + + @property + def channels(self): + return [self[d] for d in self.degrees] + + @property + def num_features(self): + """ Size of the resulting tensor if all features were concatenated together """ + return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) + + @staticmethod + def create(num_degrees: int, num_channels: int): + """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ + return Fiber([(degree, num_channels) for degree in range(num_degrees)]) + + @staticmethod + def from_features(feats: Dict[str, Tensor]): + """ Infer the Fiber structure from a feature dict """ + structure = {} + for k, v in feats.items(): + degree = int(k) + assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' + assert v.shape[-1] == degree_to_dim(degree) + structure[degree] = v.shape[-2] + return Fiber(structure) + + def __getitem__(self, degree: int): + """ fiber[degree] returns the multiplicity for this degree """ + return dict(self.structure).get(degree, 0) + + def __iter__(self): + """ Iterate over namedtuples (degree, channels) """ + return iter(self.structure) + + def __mul__(self, other): + """ + If other in an int, multiplies all the multiplicities by other. + If other is a fiber, returns the cartesian product. + """ + if isinstance(other, Fiber): + return product(self.structure, other.structure) + elif isinstance(other, int): + return Fiber({t.degree: t.channels * other for t in self.structure}) + + def __add__(self, other): + """ + If other in an int, add other to all the multiplicities. + If other is a fiber, add the multiplicities of the fibers together. + """ + if isinstance(other, Fiber): + return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) + elif isinstance(other, int): + return Fiber({t.degree: t.channels + other for t in self.structure}) + + def __repr__(self): + return str(self.structure) + + @staticmethod + def combine_max(f1, f2): + """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ + new_dict = dict(f1.structure) + for k, m in f2.structure: + new_dict[k] = max(new_dict.get(k, 0), m) + + return Fiber(list(new_dict.items())) + + @staticmethod + def combine_selectively(f1, f2): + """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ + # only use orders which occur in fiber f1 + new_dict = dict(f1.structure) + for k in f1.degrees: + if k in f2.degrees: + new_dict[k] += f2[k] + return Fiber(list(new_dict.items())) + + def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): + # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) + fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in + self.degrees] + fibers = torch.cat(fibers, -1) + return fibers diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py index 2d6b14fb..9eb9e3ce 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/__init__.py @@ -1,5 +1,5 @@ -from .linear import LinearSE3 -from .norm import NormSE3 -from .pooling import GPooling -from .convolution import ConvSE3 +from .linear import LinearSE3 +from .norm import NormSE3 +from .pooling import GPooling +from .convolution import ConvSE3 from .attention import AttentionBlockSE3 \ No newline at end of file diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py index 11a4fb16..541daa52 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/attention.py @@ -1,180 +1,181 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import dgl -import numpy as np -import torch -import torch.nn as nn -from dgl import DGLGraph -from dgl.ops import edge_softmax -from torch import Tensor -from typing import Dict, Optional, Union - -from se3_transformer.model.fiber import Fiber -from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel -from se3_transformer.model.layers.linear import LinearSE3 -from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features -from torch.cuda.nvtx import range as nvtx_range - - -class AttentionSE3(nn.Module): - """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """ - - def __init__( - self, - num_heads: int, - key_fiber: Fiber, - value_fiber: Fiber - ): - """ - :param num_heads: Number of attention heads - :param key_fiber: Fiber for the keys (and also for the queries) - :param value_fiber: Fiber for the values - """ - super().__init__() - self.num_heads = num_heads - self.key_fiber = key_fiber - self.value_fiber = value_fiber - - def forward( - self, - value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) - key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) - query: Dict[str, Tensor], # node features - graph: DGLGraph - ): - with nvtx_range('AttentionSE3'): - with nvtx_range('reshape keys and queries'): - if isinstance(key, Tensor): - # case where features of all types are fused - key = key.reshape(key.shape[0], self.num_heads, -1) - # need to reshape queries that way to keep the same layout as keys - out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1) - query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1) - else: - # features are not fused, need to fuse and reshape them - key = self.key_fiber.to_attention_heads(key, self.num_heads) - query = self.key_fiber.to_attention_heads(query, self.num_heads) - - with nvtx_range('attention dot product + softmax'): - # Compute attention weights (softmax of inner product between key and query) - edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1) - edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features) - edge_weights = edge_softmax(graph, edge_weights) - edge_weights = edge_weights[..., None, None] - - with nvtx_range('weighted sum'): - if isinstance(value, Tensor): - # features of all types are fused - v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1]) - weights = edge_weights * v - feat_out = dgl.ops.copy_e_sum(graph, weights) - feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads - out = unfuse_features(feat_out, self.value_fiber.degrees) - else: - out = {} - for degree, channels in self.value_fiber: - v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads, - degree_to_dim(degree)) - weights = edge_weights * v - res = dgl.ops.copy_e_sum(graph, weights) - out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads - - return out - - -class AttentionBlockSE3(nn.Module): - """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """ - - def __init__( - self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Optional[Fiber] = None, - num_heads: int = 4, - channels_div: int = 2, - use_layer_norm: bool = False, - max_degree: bool = 4, - fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, - **kwargs - ): - """ - :param fiber_in: Fiber describing the input features - :param fiber_out: Fiber describing the output features - :param fiber_edge: Fiber describing the edge features (node distances excluded) - :param num_heads: Number of attention heads - :param channels_div: Divide the channels by this integer for computing values - :param use_layer_norm: Apply layer normalization between MLP layers - :param max_degree: Maximum degree used in the bases computation - :param fuse_level: Maximum fuse level to use in TFN convolutions - """ - super().__init__() - if fiber_edge is None: - fiber_edge = Fiber({}) - self.fiber_in = fiber_in - # value_fiber has same structure as fiber_out but #channels divided by 'channels_div' - value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out]) - # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber - # (queries are merely projected, hence degrees have to match input) - key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees]) - - self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge, - use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level, - allow_fused_output=True) - self.to_query = LinearSE3(fiber_in, key_query_fiber) - self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber) - self.project = LinearSE3(value_fiber + fiber_in, fiber_out) - - def forward( - self, - node_features: Dict[str, Tensor], - edge_features: Dict[str, Tensor], - graph: DGLGraph, - basis: Dict[str, Tensor] - ): - with nvtx_range('AttentionBlockSE3'): - with nvtx_range('keys / values'): - fused_key_value = self.to_key_value(node_features, edge_features, graph, basis) - key, value = self._get_key_value_from_fused(fused_key_value) - - with nvtx_range('queries'): - query = self.to_query(node_features) - - z = self.attention(value, key, query, graph) - z_concat = aggregate_residual(node_features, z, 'cat') - return self.project(z_concat) - - def _get_key_value_from_fused(self, fused_key_value): - # Extract keys and queries features from fused features - if isinstance(fused_key_value, Tensor): - # Previous layer was a fully fused convolution - value, key = torch.chunk(fused_key_value, chunks=2, dim=-2) - else: - key, value = {}, {} - for degree, feat in fused_key_value.items(): - if int(degree) in self.fiber_in.degrees: - value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2) - else: - value[degree] = feat - - return key, value +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import dgl +import numpy as np +import torch +import torch.nn as nn +from dgl import DGLGraph +from dgl.ops import edge_softmax +from torch import Tensor +from typing import Dict, Optional, Union + +from se3_transformer.model.fiber import Fiber +from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel +from se3_transformer.model.layers.linear import LinearSE3 +from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features +from torch.cuda.nvtx import range as nvtx_range + + +class AttentionSE3(nn.Module): + """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """ + + def __init__( + self, + num_heads: int, + key_fiber: Fiber, + value_fiber: Fiber + ): + """ + :param num_heads: Number of attention heads + :param key_fiber: Fiber for the keys (and also for the queries) + :param value_fiber: Fiber for the values + """ + super().__init__() + self.num_heads = num_heads + self.key_fiber = key_fiber + self.value_fiber = value_fiber + + def forward( + self, + value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) + key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) + query: Dict[str, Tensor], # node features + graph: DGLGraph + ): + with nvtx_range('AttentionSE3'): + with nvtx_range('reshape keys and queries'): + if isinstance(key, Tensor): + # case where features of all types are fused + key = key.reshape(key.shape[0], self.num_heads, -1) + # need to reshape queries that way to keep the same layout as keys + out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1) + query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1) + else: + # features are not fused, need to fuse and reshape them + key = self.key_fiber.to_attention_heads(key, self.num_heads) + query = self.key_fiber.to_attention_heads(query, self.num_heads) + + with nvtx_range('attention dot product + softmax'): + # Compute attention weights (softmax of inner product between key and query) + edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1) + edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features) + edge_weights = edge_softmax(graph, edge_weights) + edge_weights = edge_weights[..., None, None] + + with nvtx_range('weighted sum'): + if isinstance(value, Tensor): + # features of all types are fused + v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1]) + weights = edge_weights * v + feat_out = dgl.ops.copy_e_sum(graph, weights) + feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads + out = unfuse_features(feat_out, self.value_fiber.degrees) + else: + out = {} + for degree, channels in self.value_fiber: + v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads, + degree_to_dim(degree)) + weights = edge_weights * v + res = dgl.ops.copy_e_sum(graph, weights) + out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads + + return out + + +class AttentionBlockSE3(nn.Module): + """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """ + + def __init__( + self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Optional[Fiber] = None, + num_heads: int = 4, + channels_div: int = 2, + use_layer_norm: bool = False, + max_degree: bool = 4, + fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, + low_memory: bool = False, + **kwargs + ): + """ + :param fiber_in: Fiber describing the input features + :param fiber_out: Fiber describing the output features + :param fiber_edge: Fiber describing the edge features (node distances excluded) + :param num_heads: Number of attention heads + :param channels_div: Divide the channels by this integer for computing values + :param use_layer_norm: Apply layer normalization between MLP layers + :param max_degree: Maximum degree used in the bases computation + :param fuse_level: Maximum fuse level to use in TFN convolutions + """ + super().__init__() + if fiber_edge is None: + fiber_edge = Fiber({}) + self.fiber_in = fiber_in + # value_fiber has same structure as fiber_out but #channels divided by 'channels_div' + value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out]) + # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber + # (queries are merely projected, hence degrees have to match input) + key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees]) + + self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge, + use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level, + allow_fused_output=True, low_memory=low_memory) + self.to_query = LinearSE3(fiber_in, key_query_fiber) + self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber) + self.project = LinearSE3(value_fiber + fiber_in, fiber_out) + + def forward( + self, + node_features: Dict[str, Tensor], + edge_features: Dict[str, Tensor], + graph: DGLGraph, + basis: Dict[str, Tensor] + ): + with nvtx_range('AttentionBlockSE3'): + with nvtx_range('keys / values'): + fused_key_value = self.to_key_value(node_features, edge_features, graph, basis) + key, value = self._get_key_value_from_fused(fused_key_value) + + with nvtx_range('queries'): + query = self.to_query(node_features) + + z = self.attention(value, key, query, graph) + z_concat = aggregate_residual(node_features, z, 'cat') + return self.project(z_concat) + + def _get_key_value_from_fused(self, fused_key_value): + # Extract keys and queries features from fused features + if isinstance(fused_key_value, Tensor): + # Previous layer was a fully fused convolution + value, key = torch.chunk(fused_key_value, chunks=2, dim=-2) + else: + key, value = {}, {} + for degree, feat in fused_key_value.items(): + if int(degree) in self.fiber_in.degrees: + value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2) + else: + value[degree] = feat + + return key, value diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py index 0cab99e1..7bf9691b 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py @@ -1,345 +1,354 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -from enum import Enum -from itertools import product -from typing import Dict - -import dgl -import numpy as np -import torch -import torch.nn as nn -from dgl import DGLGraph -from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range - -from se3_transformer.model.fiber import Fiber -from se3_transformer.runtime.utils import degree_to_dim, unfuse_features - - -class ConvSE3FuseLevel(Enum): - """ - Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met. - If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered. - A higher level means faster training, but also more memory usage. - If you are tight on memory and want to feed large inputs to the network, choose a low value. - If you want to train fast, choose a high value. - Recommended value is FULL with AMP. - - Fully fused TFN convolutions requirements: - - all input channels are the same - - all output channels are the same - - input degrees span the range [0, ..., max_degree] - - output degrees span the range [0, ..., max_degree] - - Partially fused TFN convolutions requirements: - * For fusing by output degree: - - all input channels are the same - - input degrees span the range [0, ..., max_degree] - * For fusing by input degree: - - all output channels are the same - - output degrees span the range [0, ..., max_degree] - - Original TFN pairwise convolutions: no requirements - """ - - FULL = 2 - PARTIAL = 1 - NONE = 0 - - -class RadialProfile(nn.Module): - """ - Radial profile function. - Outputs weights used to weigh basis matrices in order to get convolution kernels. - In TFN notation: $R^{l,k}$ - In SE(3)-Transformer notation: $\phi^{l,k}$ - - Note: - In the original papers, this function only depends on relative node distances ||x||. - Here, we allow this function to also take as input additional invariant edge features. - This does not break equivariance and adds expressive power to the model. - - Diagram: - invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights - """ - - def __init__( - self, - num_freq: int, - channels_in: int, - channels_out: int, - edge_dim: int = 1, - mid_dim: int = 32, - use_layer_norm: bool = False - ): - """ - :param num_freq: Number of frequencies - :param channels_in: Number of input channels - :param channels_out: Number of output channels - :param edge_dim: Number of invariant edge features (input to the radial function) - :param mid_dim: Size of the hidden MLP layers - :param use_layer_norm: Apply layer normalization between MLP layers - """ - super().__init__() - modules = [ - nn.Linear(edge_dim, mid_dim), - nn.LayerNorm(mid_dim) if use_layer_norm else None, - nn.ReLU(), - nn.Linear(mid_dim, mid_dim), - nn.LayerNorm(mid_dim) if use_layer_norm else None, - nn.ReLU(), - nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False) - ] - - self.net = nn.Sequential(*[m for m in modules if m is not None]) - - def forward(self, features: Tensor) -> Tensor: - return self.net(features) - - -class VersatileConvSE3(nn.Module): - """ - Building block for TFN convolutions. - This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions. - """ - - def __init__(self, - freq_sum: int, - channels_in: int, - channels_out: int, - edge_dim: int, - use_layer_norm: bool, - fuse_level: ConvSE3FuseLevel): - super().__init__() - self.freq_sum = freq_sum - self.channels_out = channels_out - self.channels_in = channels_in - self.fuse_level = fuse_level - self.radial_func = RadialProfile(num_freq=freq_sum, - channels_in=channels_in, - channels_out=channels_out, - edge_dim=edge_dim, - use_layer_norm=use_layer_norm) - - def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor): - with nvtx_range(f'VersatileConvSE3'): - num_edges = features.shape[0] - in_dim = features.shape[2] - with nvtx_range(f'RadialProfile'): - radial_weights = self.radial_func(invariant_edge_feats) \ - .view(-1, self.channels_out, self.channels_in * self.freq_sum) - - if basis is not None: - # This block performs the einsum n i l, n o i f, n l f k -> n o k - basis_view = basis.view(num_edges, in_dim, -1) - tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1]) - return radial_weights @ tmp - else: - # k = l = 0 non-fused case - return radial_weights @ features - - -class ConvSE3(nn.Module): - """ - SE(3)-equivariant graph convolution (Tensor Field Network convolution). - This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance. - Features of different degrees interact together to produce output features. - - Note 1: - The option is given to not pool the output. This means that the convolution sum over neighbors will not be - done, and the returned features will be edge features instead of node features. - - Note 2: - Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0. - Input edge features are concatenated with input source node features before the kernel is applied. - """ - - def __init__( - self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Fiber, - pool: bool = True, - use_layer_norm: bool = False, - self_interaction: bool = False, - max_degree: int = 4, - fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, - allow_fused_output: bool = False - ): - """ - :param fiber_in: Fiber describing the input features - :param fiber_out: Fiber describing the output features - :param fiber_edge: Fiber describing the edge features (node distances excluded) - :param pool: If True, compute final node features by averaging incoming edge features - :param use_layer_norm: Apply layer normalization between MLP layers - :param self_interaction: Apply self-interaction of nodes - :param max_degree: Maximum degree used in the bases computation - :param fuse_level: Maximum fuse level to use in TFN convolutions - :param allow_fused_output: Allow the module to output a fused representation of features - """ - super().__init__() - self.pool = pool - self.fiber_in = fiber_in - self.fiber_out = fiber_out - self.self_interaction = self_interaction - self.max_degree = max_degree - self.allow_fused_output = allow_fused_output - - # channels_in: account for the concatenation of edge features - channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in]) - channels_out_set = set([f.channels for f in self.fiber_out]) - unique_channels_in = (len(channels_in_set) == 1) - unique_channels_out = (len(channels_out_set) == 1) - degrees_up_to_max = list(range(max_degree + 1)) - common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm) - - if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \ - unique_channels_in and fiber_in.degrees == degrees_up_to_max and \ - unique_channels_out and fiber_out.degrees == degrees_up_to_max: - # Single fused convolution - self.used_fuse_level = ConvSE3FuseLevel.FULL - - sum_freq = sum([ - degree_to_dim(min(d_in, d_out)) - for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max) - ]) - - self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0], - fuse_level=self.used_fuse_level, **common_args) - - elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ - unique_channels_in and fiber_in.degrees == degrees_up_to_max: - # Convolutions fused per output degree - self.used_fuse_level = ConvSE3FuseLevel.PARTIAL - self.conv_out = nn.ModuleDict() - for d_out, c_out in fiber_out: - sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees]) - self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out, - fuse_level=self.used_fuse_level, **common_args) - - elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ - unique_channels_out and fiber_out.degrees == degrees_up_to_max: - # Convolutions fused per input degree - self.used_fuse_level = ConvSE3FuseLevel.PARTIAL - self.conv_in = nn.ModuleDict() - for d_in, c_in in fiber_in: - sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees]) - channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0) - self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0], - fuse_level=self.used_fuse_level, **common_args) - else: - # Use pairwise TFN convolutions - self.used_fuse_level = ConvSE3FuseLevel.NONE - self.conv = nn.ModuleDict() - for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out): - dict_key = f'{degree_in},{degree_out}' - channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0) - sum_freq = degree_to_dim(min(degree_in, degree_out)) - self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out, - fuse_level=self.used_fuse_level, **common_args) - - if self_interaction: - self.to_kernel_self = nn.ParameterDict() - for degree_out, channels_out in fiber_out: - if fiber_in[degree_out]: - self.to_kernel_self[str(degree_out)] = nn.Parameter( - torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) - - def _try_unpad(self, feature, basis): - # Account for padded basis - if basis is not None: - out_dim = basis.shape[-1] - out_dim += out_dim % 2 - 1 - return feature[..., :out_dim] - else: - return feature - - def forward( - self, - node_feats: Dict[str, Tensor], - edge_feats: Dict[str, Tensor], - graph: DGLGraph, - basis: Dict[str, Tensor] - ): - with nvtx_range(f'ConvSE3'): - invariant_edge_feats = edge_feats['0'].squeeze(-1) - src, dst = graph.edges() - out = {} - in_features = [] - - # Fetch all input features from edge and node features - for degree_in in self.fiber_in.degrees: - src_node_features = node_feats[str(degree_in)][src] - if degree_in > 0 and str(degree_in) in edge_feats: - # Handle edge features of any type by concatenating them to node features - src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1) - in_features.append(src_node_features) - - if self.used_fuse_level == ConvSE3FuseLevel.FULL: - in_features_fused = torch.cat(in_features, dim=-1) - out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused']) - - if not self.allow_fused_output or self.self_interaction or self.pool: - out = unfuse_features(out, self.fiber_out.degrees) - - elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'): - in_features_fused = torch.cat(in_features, dim=-1) - for degree_out in self.fiber_out.degrees: - basis_used = basis[f'out{degree_out}_fused'] - out[str(degree_out)] = self._try_unpad( - self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used), - basis_used) - - elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'): - out = 0 - for degree_in, feature in zip(self.fiber_in.degrees, in_features): - out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused']) - if not self.allow_fused_output or self.self_interaction or self.pool: - out = unfuse_features(out, self.fiber_out.degrees) - else: - # Fallback to pairwise TFN convolutions - for degree_out in self.fiber_out.degrees: - out_feature = 0 - for degree_in, feature in zip(self.fiber_in.degrees, in_features): - dict_key = f'{degree_in},{degree_out}' - basis_used = basis.get(dict_key, None) - out_feature = out_feature + self._try_unpad( - self.conv[dict_key](feature, invariant_edge_feats, basis_used), - basis_used) - out[str(degree_out)] = out_feature - - for degree_out in self.fiber_out.degrees: - if self.self_interaction and str(degree_out) in self.to_kernel_self: - with nvtx_range(f'self interaction'): - dst_features = node_feats[str(degree_out)][dst] - kernel_self = self.to_kernel_self[str(degree_out)] - out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features - - if self.pool: - with nvtx_range(f'pooling'): - if isinstance(out, dict): - out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)]) - else: - out = dgl.ops.copy_e_sum(graph, out) - return out +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +from enum import Enum +from itertools import product +from typing import Dict + +import dgl +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from dgl import DGLGraph +from torch import Tensor +from torch.cuda.nvtx import range as nvtx_range + +from se3_transformer.model.fiber import Fiber +from se3_transformer.runtime.utils import degree_to_dim, unfuse_features + + +class ConvSE3FuseLevel(Enum): + """ + Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met. + If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered. + A higher level means faster training, but also more memory usage. + If you are tight on memory and want to feed large inputs to the network, choose a low value. + If you want to train fast, choose a high value. + Recommended value is FULL with AMP. + + Fully fused TFN convolutions requirements: + - all input channels are the same + - all output channels are the same + - input degrees span the range [0, ..., max_degree] + - output degrees span the range [0, ..., max_degree] + + Partially fused TFN convolutions requirements: + * For fusing by output degree: + - all input channels are the same + - input degrees span the range [0, ..., max_degree] + * For fusing by input degree: + - all output channels are the same + - output degrees span the range [0, ..., max_degree] + + Original TFN pairwise convolutions: no requirements + """ + + FULL = 2 + PARTIAL = 1 + NONE = 0 + + +class RadialProfile(nn.Module): + """ + Radial profile function. + Outputs weights used to weigh basis matrices in order to get convolution kernels. + In TFN notation: $R^{l,k}$ + In SE(3)-Transformer notation: $\phi^{l,k}$ + + Note: + In the original papers, this function only depends on relative node distances ||x||. + Here, we allow this function to also take as input additional invariant edge features. + This does not break equivariance and adds expressive power to the model. + + Diagram: + invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights + """ + + def __init__( + self, + num_freq: int, + channels_in: int, + channels_out: int, + edge_dim: int = 1, + mid_dim: int = 32, + use_layer_norm: bool = False + ): + """ + :param num_freq: Number of frequencies + :param channels_in: Number of input channels + :param channels_out: Number of output channels + :param edge_dim: Number of invariant edge features (input to the radial function) + :param mid_dim: Size of the hidden MLP layers + :param use_layer_norm: Apply layer normalization between MLP layers + """ + super().__init__() + modules = [ + nn.Linear(edge_dim, mid_dim), + nn.LayerNorm(mid_dim) if use_layer_norm else None, + nn.ReLU(), + nn.Linear(mid_dim, mid_dim), + nn.LayerNorm(mid_dim) if use_layer_norm else None, + nn.ReLU(), + nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False) + ] + + self.net = nn.Sequential(*[m for m in modules if m is not None]) + + def forward(self, features: Tensor) -> Tensor: + return self.net(features) + + +class VersatileConvSE3(nn.Module): + """ + Building block for TFN convolutions. + This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions. + """ + + def __init__(self, + freq_sum: int, + channels_in: int, + channels_out: int, + edge_dim: int, + use_layer_norm: bool, + fuse_level: ConvSE3FuseLevel): + super().__init__() + self.freq_sum = freq_sum + self.channels_out = channels_out + self.channels_in = channels_in + self.fuse_level = fuse_level + self.radial_func = RadialProfile(num_freq=freq_sum, + channels_in=channels_in, + channels_out=channels_out, + edge_dim=edge_dim, + use_layer_norm=use_layer_norm) + + def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor): + with nvtx_range(f'VersatileConvSE3'): + num_edges = features.shape[0] + in_dim = features.shape[2] + with nvtx_range(f'RadialProfile'): + radial_weights = self.radial_func(invariant_edge_feats) \ + .view(-1, self.channels_out, self.channels_in * self.freq_sum) + + if basis is not None: + # This block performs the einsum n i l, n o i f, n l f k -> n o k + basis_view = basis.view(num_edges, in_dim, -1) + tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1]) + return radial_weights @ tmp + else: + # k = l = 0 non-fused case + return radial_weights @ features + + +class ConvSE3(nn.Module): + """ + SE(3)-equivariant graph convolution (Tensor Field Network convolution). + This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance. + Features of different degrees interact together to produce output features. + + Note 1: + The option is given to not pool the output. This means that the convolution sum over neighbors will not be + done, and the returned features will be edge features instead of node features. + + Note 2: + Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0. + Input edge features are concatenated with input source node features before the kernel is applied. + """ + + def __init__( + self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Fiber, + pool: bool = True, + use_layer_norm: bool = False, + self_interaction: bool = False, + max_degree: int = 4, + fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, + allow_fused_output: bool = False, + low_memory: bool = False + ): + """ + :param fiber_in: Fiber describing the input features + :param fiber_out: Fiber describing the output features + :param fiber_edge: Fiber describing the edge features (node distances excluded) + :param pool: If True, compute final node features by averaging incoming edge features + :param use_layer_norm: Apply layer normalization between MLP layers + :param self_interaction: Apply self-interaction of nodes + :param max_degree: Maximum degree used in the bases computation + :param fuse_level: Maximum fuse level to use in TFN convolutions + :param allow_fused_output: Allow the module to output a fused representation of features + """ + super().__init__() + self.pool = pool + self.fiber_in = fiber_in + self.fiber_out = fiber_out + self.self_interaction = self_interaction + self.max_degree = max_degree + self.allow_fused_output = allow_fused_output + self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x) + + # channels_in: account for the concatenation of edge features + channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in]) + channels_out_set = set([f.channels for f in self.fiber_out]) + unique_channels_in = (len(channels_in_set) == 1) + unique_channels_out = (len(channels_out_set) == 1) + degrees_up_to_max = list(range(max_degree + 1)) + common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm) + + if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \ + unique_channels_in and fiber_in.degrees == degrees_up_to_max and \ + unique_channels_out and fiber_out.degrees == degrees_up_to_max: + # Single fused convolution + self.used_fuse_level = ConvSE3FuseLevel.FULL + + sum_freq = sum([ + degree_to_dim(min(d_in, d_out)) + for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max) + ]) + + self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0], + fuse_level=self.used_fuse_level, **common_args) + + elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ + unique_channels_in and fiber_in.degrees == degrees_up_to_max: + # Convolutions fused per output degree + self.used_fuse_level = ConvSE3FuseLevel.PARTIAL + self.conv_out = nn.ModuleDict() + for d_out, c_out in fiber_out: + sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees]) + self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out, + fuse_level=self.used_fuse_level, **common_args) + + elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ + unique_channels_out and fiber_out.degrees == degrees_up_to_max: + # Convolutions fused per input degree + self.used_fuse_level = ConvSE3FuseLevel.PARTIAL + self.conv_in = nn.ModuleDict() + for d_in, c_in in fiber_in: + channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0) + sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees]) + self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0], + fuse_level=self.used_fuse_level, **common_args) + else: + # Use pairwise TFN convolutions + self.used_fuse_level = ConvSE3FuseLevel.NONE + self.conv = nn.ModuleDict() + for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out): + dict_key = f'{degree_in},{degree_out}' + channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0) + sum_freq = degree_to_dim(min(degree_in, degree_out)) + self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out, + fuse_level=self.used_fuse_level, **common_args) + + if self_interaction: + self.to_kernel_self = nn.ParameterDict() + for degree_out, channels_out in fiber_out: + if fiber_in[degree_out]: + self.to_kernel_self[str(degree_out)] = nn.Parameter( + torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) + + def _try_unpad(self, feature, basis): + # Account for padded basis + if basis is not None: + out_dim = basis.shape[-1] + out_dim += out_dim % 2 - 1 + return feature[..., :out_dim] + else: + return feature + + def forward( + self, + node_feats: Dict[str, Tensor], + edge_feats: Dict[str, Tensor], + graph: DGLGraph, + basis: Dict[str, Tensor] + ): + with nvtx_range(f'ConvSE3'): + invariant_edge_feats = edge_feats['0'].squeeze(-1) + src, dst = graph.edges() + out = {} + in_features = [] + + # Fetch all input features from edge and node features + for degree_in in self.fiber_in.degrees: + src_node_features = node_feats[str(degree_in)][src] + if degree_in > 0 and str(degree_in) in edge_feats: + # Handle edge features of any type by concatenating them to node features + src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1) + in_features.append(src_node_features) + + if self.used_fuse_level == ConvSE3FuseLevel.FULL: + in_features_fused = torch.cat(in_features, dim=-1) + out = self.conv_checkpoint( + self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused'] + ) + + if not self.allow_fused_output or self.self_interaction or self.pool: + out = unfuse_features(out, self.fiber_out.degrees) + + elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'): + in_features_fused = torch.cat(in_features, dim=-1) + for degree_out in self.fiber_out.degrees: + basis_used = basis[f'out{degree_out}_fused'] + out[str(degree_out)] = self._try_unpad( + self.conv_checkpoint( + self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used + ), basis_used) + + elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'): + out = 0 + for degree_in, feature in zip(self.fiber_in.degrees, in_features): + out = out + self.conv_checkpoint( + self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused'] + ) + if not self.allow_fused_output or self.self_interaction or self.pool: + out = unfuse_features(out, self.fiber_out.degrees) + else: + # Fallback to pairwise TFN convolutions + for degree_out in self.fiber_out.degrees: + out_feature = 0 + for degree_in, feature in zip(self.fiber_in.degrees, in_features): + dict_key = f'{degree_in},{degree_out}' + basis_used = basis.get(dict_key, None) + out_feature = out_feature + self._try_unpad( + self.conv_checkpoint( + self.conv[dict_key], feature, invariant_edge_feats, basis_used + ), basis_used) + out[str(degree_out)] = out_feature + + for degree_out in self.fiber_out.degrees: + if self.self_interaction and str(degree_out) in self.to_kernel_self: + with nvtx_range(f'self interaction'): + dst_features = node_feats[str(degree_out)][dst] + kernel_self = self.to_kernel_self[str(degree_out)] + out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features + + if self.pool: + with nvtx_range(f'pooling'): + if isinstance(out, dict): + out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)]) + else: + out = dgl.ops.copy_e_sum(graph, out) + return out diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py index c138d897..f720d77e 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/linear.py @@ -1,59 +1,59 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from typing import Dict - -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor - -from se3_transformer.model.fiber import Fiber - - -class LinearSE3(nn.Module): - """ - Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. - Maps a fiber to a fiber with the same degrees (channels may be different). - No interaction between degrees, but interaction between channels. - - type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels) - type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels) - : - type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels) - """ - - def __init__(self, fiber_in: Fiber, fiber_out: Fiber): - super().__init__() - self.weights = nn.ParameterDict({ - str(degree_out): nn.Parameter( - torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) - for degree_out, channels_out in fiber_out - }) - - def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: - return { - degree: self.weights[degree] @ features[degree] - for degree, weight in self.weights.items() - } +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from se3_transformer.model.fiber import Fiber + + +class LinearSE3(nn.Module): + """ + Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. + Maps a fiber to a fiber with the same degrees (channels may be different). + No interaction between degrees, but interaction between channels. + + type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels) + type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels) + : + type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels) + """ + + def __init__(self, fiber_in: Fiber, fiber_out: Fiber): + super().__init__() + self.weights = nn.ParameterDict({ + str(degree_out): nn.Parameter( + torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) + for degree_out, channels_out in fiber_out + }) + + def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: + return { + degree: self.weights[degree] @ features[degree] + for degree, weight in self.weights.items() + } diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py index 71c753f6..acbe23d7 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py @@ -1,83 +1,83 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - - -from typing import Dict - -import torch -import torch.nn as nn -from torch import Tensor -from torch.cuda.nvtx import range as nvtx_range - -from se3_transformer.model.fiber import Fiber - - -class NormSE3(nn.Module): - """ - Norm-based SE(3)-equivariant nonlinearity. - - ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐ - feature_in ──┤ * ──> feature_out - └──> feature_phase ────────────────────────────┘ - """ - - NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 - - def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): - super().__init__() - self.fiber = fiber - self.nonlinearity = nonlinearity - - if len(set(fiber.channels)) == 1: - # Fuse all the layer normalizations into a group normalization - self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) - else: - # Use multiple layer normalizations - self.layer_norms = nn.ModuleDict({ - str(degree): nn.LayerNorm(channels) - for degree, channels in fiber - }) - - def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: - with nvtx_range('NormSE3'): - output = {} - if hasattr(self, 'group_norm'): - # Compute per-degree norms of features - norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) - for d in self.fiber.degrees] - fused_norms = torch.cat(norms, dim=-2) - - # Transform the norms only - new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) - new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) - - # Scale features to the new norms - for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): - output[str(d)] = features[str(d)] / norm * new_norm - else: - for degree, feat in features.items(): - norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) - new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) - output[degree] = new_norm * feat / norm - - return output +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + + +from typing import Dict + +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.nvtx import range as nvtx_range + +from se3_transformer.model.fiber import Fiber + + +class NormSE3(nn.Module): + """ + Norm-based SE(3)-equivariant nonlinearity. + + ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐ + feature_in ──┤ * ──> feature_out + └──> feature_phase ────────────────────────────┘ + """ + + NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 + + def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): + super().__init__() + self.fiber = fiber + self.nonlinearity = nonlinearity + + if len(set(fiber.channels)) == 1: + # Fuse all the layer normalizations into a group normalization + self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) + else: + # Use multiple layer normalizations + self.layer_norms = nn.ModuleDict({ + str(degree): nn.LayerNorm(channels) + for degree, channels in fiber + }) + + def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: + with nvtx_range('NormSE3'): + output = {} + if hasattr(self, 'group_norm'): + # Compute per-degree norms of features + norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) + for d in self.fiber.degrees] + fused_norms = torch.cat(norms, dim=-2) + + # Transform the norms only + new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) + new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) + + # Scale features to the new norms + for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): + output[str(d)] = features[str(d)] / norm * new_norm + else: + for degree, feat in features.items(): + norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) + new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) + output[degree] = new_norm * feat / norm + + return output diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py index 273aab6c..e42c5383 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/pooling.py @@ -1,53 +1,53 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -from typing import Dict, Literal - -import torch.nn as nn -from dgl import DGLGraph -from dgl.nn.pytorch import AvgPooling, MaxPooling -from torch import Tensor - - -class GPooling(nn.Module): - """ - Graph max/average pooling on a given feature type. - The average can be taken for any feature type, and equivariance will be maintained. - The maximum can only be taken for invariant features (type 0). - If you want max-pooling for type > 0 features, look into Vector Neurons. - """ - - def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): - """ - :param feat_type: Feature type to pool - :param pool: Type of pooling: max or avg - """ - super().__init__() - assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' - assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' - self.feat_type = feat_type - self.pool = MaxPooling() if pool == 'max' else AvgPooling() - - def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: - pooled = self.pool(graph, features[str(self.feat_type)]) - return pooled.squeeze(dim=-1) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +from typing import Dict, Literal + +import torch.nn as nn +from dgl import DGLGraph +from dgl.nn.pytorch import AvgPooling, MaxPooling +from torch import Tensor + + +class GPooling(nn.Module): + """ + Graph max/average pooling on a given feature type. + The average can be taken for any feature type, and equivariance will be maintained. + The maximum can only be taken for invariant features (type 0). + If you want max-pooling for type > 0 features, look into Vector Neurons. + """ + + def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): + """ + :param feat_type: Feature type to pool + :param pool: Type of pooling: max or avg + """ + super().__init__() + assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' + assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' + self.feat_type = feat_type + self.pool = MaxPooling() if pool == 'max' else AvgPooling() + + def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: + pooled = self.pool(graph, features[str(self.feat_type)]) + return pooled.squeeze(dim=-1) diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py index bec7cc8e..f02bd84e 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py @@ -1,222 +1,223 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import logging -from typing import Optional, Literal, Dict - -import torch -import torch.nn as nn -from dgl import DGLGraph -from torch import Tensor - -from se3_transformer.model.basis import get_basis, update_basis_with_fused -from se3_transformer.model.layers.attention import AttentionBlockSE3 -from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel -from se3_transformer.model.layers.norm import NormSE3 -from se3_transformer.model.layers.pooling import GPooling -from se3_transformer.runtime.utils import str2bool -from se3_transformer.model.fiber import Fiber - - -class Sequential(nn.Sequential): - """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """ - - def forward(self, input, *args, **kwargs): - for module in self: - input = module(input, *args, **kwargs) - return input - - -def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None): - """ Add relative positions to existing edge features """ - edge_features = edge_features.copy() if edge_features else {} - r = relative_pos.norm(dim=-1, keepdim=True) - if '0' in edge_features: - edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1) - else: - edge_features['0'] = r[..., None] - - return edge_features - - -class SE3Transformer(nn.Module): - def __init__(self, - num_layers: int, - fiber_in: Fiber, - fiber_hidden: Fiber, - fiber_out: Fiber, - num_heads: int, - channels_div: int, - fiber_edge: Fiber = Fiber({}), - return_type: Optional[int] = None, - pooling: Optional[Literal['avg', 'max']] = None, - norm: bool = True, - use_layer_norm: bool = True, - tensor_cores: bool = False, - low_memory: bool = False, - **kwargs): - """ - :param num_layers: Number of attention layers - :param fiber_in: Input fiber description - :param fiber_hidden: Hidden fiber description - :param fiber_out: Output fiber description - :param fiber_edge: Input edge fiber description - :param num_heads: Number of attention heads - :param channels_div: Channels division before feeding to attention layer - :param return_type: Return only features of this type - :param pooling: 'avg' or 'max' graph pooling before MLP layers - :param norm: Apply a normalization layer after each attention block - :param use_layer_norm: Apply layer normalization between MLP layers - :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases) - :param low_memory: If True, will use slower ops that use less memory - """ - super().__init__() - self.num_layers = num_layers - self.fiber_edge = fiber_edge - self.num_heads = num_heads - self.channels_div = channels_div - self.return_type = return_type - self.pooling = pooling - self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees) - self.tensor_cores = tensor_cores - self.low_memory = low_memory - - if low_memory and not tensor_cores: - logging.warning('Low memory mode will have no effect with no Tensor Cores') - - # Fully fused convolutions when using Tensor Cores (and not low memory mode) - fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL - - graph_modules = [] - for i in range(num_layers): - graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in, - fiber_out=fiber_hidden, - fiber_edge=fiber_edge, - num_heads=num_heads, - channels_div=channels_div, - use_layer_norm=use_layer_norm, - max_degree=self.max_degree, - fuse_level=fuse_level)) - if norm: - graph_modules.append(NormSE3(fiber_hidden)) - fiber_in = fiber_hidden - - graph_modules.append(ConvSE3(fiber_in=fiber_in, - fiber_out=fiber_out, - fiber_edge=fiber_edge, - self_interaction=True, - use_layer_norm=use_layer_norm, - max_degree=self.max_degree)) - self.graph_modules = Sequential(*graph_modules) - - if pooling is not None: - assert return_type is not None, 'return_type must be specified when pooling' - self.pooling_module = GPooling(pool=pooling, feat_type=return_type) - - def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor], - edge_feats: Optional[Dict[str, Tensor]] = None, - basis: Optional[Dict[str, Tensor]] = None): - # Compute bases in case they weren't precomputed as part of the data loading - basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False, - use_pad_trick=self.tensor_cores and not self.low_memory, - amp=torch.is_autocast_enabled()) - - # Add fused bases (per output degree, per input degree, and fully fused) to the dict - basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory, - fully_fused=self.tensor_cores and not self.low_memory) - - edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats) - - node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis) - - if self.pooling is not None: - return self.pooling_module(node_feats, graph=graph) - - if self.return_type is not None: - return node_feats[str(self.return_type)] - - return node_feats - - @staticmethod - def add_argparse_args(parser): - parser.add_argument('--num_layers', type=int, default=7, - help='Number of stacked Transformer layers') - parser.add_argument('--num_heads', type=int, default=8, - help='Number of heads in self-attention') - parser.add_argument('--channels_div', type=int, default=2, - help='Channels division before feeding to attention layer') - parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'], - help='Type of graph pooling') - parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False, - help='Apply a normalization layer after each attention block') - parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False, - help='Apply layer normalization between MLP layers') - parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False, - help='If true, will use fused ops that are slower but that use less memory ' - '(expect 25 percent less memory). ' - 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs') - - return parser - - -class SE3TransformerPooled(nn.Module): - def __init__(self, - fiber_in: Fiber, - fiber_out: Fiber, - fiber_edge: Fiber, - num_degrees: int, - num_channels: int, - output_dim: int, - **kwargs): - super().__init__() - kwargs['pooling'] = kwargs['pooling'] or 'max' - self.transformer = SE3Transformer( - fiber_in=fiber_in, - fiber_hidden=Fiber.create(num_degrees, num_channels), - fiber_out=fiber_out, - fiber_edge=fiber_edge, - return_type=0, - **kwargs - ) - - n_out_features = fiber_out.num_features - self.mlp = nn.Sequential( - nn.Linear(n_out_features, n_out_features), - nn.ReLU(), - nn.Linear(n_out_features, output_dim) - ) - - def forward(self, graph, node_feats, edge_feats, basis=None): - feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1) - y = self.mlp(feats).squeeze(-1) - return y - - @staticmethod - def add_argparse_args(parent_parser): - parser = parent_parser.add_argument_group("Model architecture") - SE3Transformer.add_argparse_args(parser) - parser.add_argument('--num_degrees', - help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]', - type=int, default=4) - parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32) - return parent_parser +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import logging +from typing import Optional, Literal, Dict + +import torch +import torch.nn as nn +from dgl import DGLGraph +from torch import Tensor + +from se3_transformer.model.basis import get_basis, update_basis_with_fused +from se3_transformer.model.layers.attention import AttentionBlockSE3 +from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel +from se3_transformer.model.layers.norm import NormSE3 +from se3_transformer.model.layers.pooling import GPooling +from se3_transformer.runtime.utils import str2bool +from se3_transformer.model.fiber import Fiber + + +class Sequential(nn.Sequential): + """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """ + + def forward(self, input, *args, **kwargs): + for module in self: + input = module(input, *args, **kwargs) + return input + + +def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None): + """ Add relative positions to existing edge features """ + edge_features = edge_features.copy() if edge_features else {} + r = relative_pos.norm(dim=-1, keepdim=True) + if '0' in edge_features: + edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1) + else: + edge_features['0'] = r[..., None] + + return edge_features + + +class SE3Transformer(nn.Module): + def __init__(self, + num_layers: int, + fiber_in: Fiber, + fiber_hidden: Fiber, + fiber_out: Fiber, + num_heads: int, + channels_div: int, + fiber_edge: Fiber = Fiber({}), + return_type: Optional[int] = None, + pooling: Optional[Literal['avg', 'max']] = None, + norm: bool = True, + use_layer_norm: bool = True, + tensor_cores: bool = False, + low_memory: bool = False, + **kwargs): + """ + :param num_layers: Number of attention layers + :param fiber_in: Input fiber description + :param fiber_hidden: Hidden fiber description + :param fiber_out: Output fiber description + :param fiber_edge: Input edge fiber description + :param num_heads: Number of attention heads + :param channels_div: Channels division before feeding to attention layer + :param return_type: Return only features of this type + :param pooling: 'avg' or 'max' graph pooling before MLP layers + :param norm: Apply a normalization layer after each attention block + :param use_layer_norm: Apply layer normalization between MLP layers + :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases) + :param low_memory: If True, will use slower ops that use less memory + """ + super().__init__() + self.num_layers = num_layers + self.fiber_edge = fiber_edge + self.num_heads = num_heads + self.channels_div = channels_div + self.return_type = return_type + self.pooling = pooling + self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees) + self.tensor_cores = tensor_cores + self.low_memory = low_memory + + if low_memory: + self.fuse_level = ConvSE3FuseLevel.NONE + else: + # Fully fused convolutions when using Tensor Cores (and not low memory mode) + self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL + + graph_modules = [] + for i in range(num_layers): + graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in, + fiber_out=fiber_hidden, + fiber_edge=fiber_edge, + num_heads=num_heads, + channels_div=channels_div, + use_layer_norm=use_layer_norm, + max_degree=self.max_degree, + fuse_level=self.fuse_level, + low_memory=low_memory)) + if norm: + graph_modules.append(NormSE3(fiber_hidden)) + fiber_in = fiber_hidden + + graph_modules.append(ConvSE3(fiber_in=fiber_in, + fiber_out=fiber_out, + fiber_edge=fiber_edge, + self_interaction=True, + use_layer_norm=use_layer_norm, + max_degree=self.max_degree)) + self.graph_modules = Sequential(*graph_modules) + + if pooling is not None: + assert return_type is not None, 'return_type must be specified when pooling' + self.pooling_module = GPooling(pool=pooling, feat_type=return_type) + + def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor], + edge_feats: Optional[Dict[str, Tensor]] = None, + basis: Optional[Dict[str, Tensor]] = None): + # Compute bases in case they weren't precomputed as part of the data loading + basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False, + use_pad_trick=self.tensor_cores and not self.low_memory, + amp=torch.is_autocast_enabled()) + + # Add fused bases (per output degree, per input degree, and fully fused) to the dict + basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory, + fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL) + + edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats) + + node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis) + + if self.pooling is not None: + return self.pooling_module(node_feats, graph=graph) + + if self.return_type is not None: + return node_feats[str(self.return_type)] + + return node_feats + + @staticmethod + def add_argparse_args(parser): + parser.add_argument('--num_layers', type=int, default=7, + help='Number of stacked Transformer layers') + parser.add_argument('--num_heads', type=int, default=8, + help='Number of heads in self-attention') + parser.add_argument('--channels_div', type=int, default=2, + help='Channels division before feeding to attention layer') + parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'], + help='Type of graph pooling') + parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False, + help='Apply a normalization layer after each attention block') + parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False, + help='Apply layer normalization between MLP layers') + parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False, + help='If true, will use fused ops that are slower but that use less memory ' + '(expect 25 percent less memory). ' + 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs') + + return parser + + +class SE3TransformerPooled(nn.Module): + def __init__(self, + fiber_in: Fiber, + fiber_out: Fiber, + fiber_edge: Fiber, + num_degrees: int, + num_channels: int, + output_dim: int, + **kwargs): + super().__init__() + kwargs['pooling'] = kwargs['pooling'] or 'max' + self.transformer = SE3Transformer( + fiber_in=fiber_in, + fiber_hidden=Fiber.create(num_degrees, num_channels), + fiber_out=fiber_out, + fiber_edge=fiber_edge, + return_type=0, + **kwargs + ) + + n_out_features = fiber_out.num_features + self.mlp = nn.Sequential( + nn.Linear(n_out_features, n_out_features), + nn.ReLU(), + nn.Linear(n_out_features, output_dim) + ) + + def forward(self, graph, node_feats, edge_feats, basis=None): + feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1) + y = self.mlp(feats).squeeze(-1) + return y + + @staticmethod + def add_argparse_args(parent_parser): + parser = parent_parser.add_argument_group("Model architecture") + SE3Transformer.add_argparse_args(parser) + parser.add_argument('--num_degrees', + help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]', + type=int, default=4) + parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32) + return parent_parser diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py index 7730ddcf..eea0b36c 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/arguments.py @@ -1,70 +1,72 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import argparse -import pathlib - -from se3_transformer.data_loading import QM9DataModule -from se3_transformer.model import SE3TransformerPooled -from se3_transformer.runtime.utils import str2bool - -PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') - -paths = PARSER.add_argument_group('Paths') -paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), - help='Directory where the data is located or should be downloaded') -paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), - help='Directory where the results logs should be saved') -paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', - help='Name for the resulting DLLogger JSON file') -paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, - help='File where the checkpoint should be saved') -paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, - help='File of the checkpoint to be loaded') - -optimizer = PARSER.add_argument_group('Optimizer') -optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') -optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) -optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) -optimizer.add_argument('--momentum', type=float, default=0.9) -optimizer.add_argument('--weight_decay', type=float, default=0.1) - -PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') -PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') -PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') -PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') - -PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') -PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') -PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') -PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') -PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20, - help='Do an evaluation round every N epochs') -PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, - help='Minimize stdout output') - -PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, - help='Benchmark mode') - -QM9DataModule.add_argparse_args(PARSER) -SE3TransformerPooled.add_argparse_args(PARSER) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import argparse +import pathlib + +from se3_transformer.data_loading import QM9DataModule +from se3_transformer.model import SE3TransformerPooled +from se3_transformer.runtime.utils import str2bool + +PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') + +paths = PARSER.add_argument_group('Paths') +paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), + help='Directory where the data is located or should be downloaded') +paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), + help='Directory where the results logs should be saved') +paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', + help='Name for the resulting DLLogger JSON file') +paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, + help='File where the checkpoint should be saved') +paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, + help='File of the checkpoint to be loaded') + +optimizer = PARSER.add_argument_group('Optimizer') +optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') +optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) +optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) +optimizer.add_argument('--momentum', type=float, default=0.9) +optimizer.add_argument('--weight_decay', type=float, default=0.1) + +PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') +PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') +PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') +PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') + +PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') +PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') +PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') +PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') +PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20, + help='Do an evaluation round every N epochs') +PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, + help='Minimize stdout output') +PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False, + help='Enable W&B logging') + +PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, + help='Benchmark mode') + +QM9DataModule.add_argparse_args(PARSER) +SE3TransformerPooled.add_argparse_args(PARSER) diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py index 21e9125b..e81088b4 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py @@ -32,7 +32,7 @@ from tqdm import tqdm from se3_transformer.runtime import gpu_affinity from se3_transformer.runtime.arguments import PARSER from se3_transformer.runtime.callbacks import BaseCallback -from se3_transformer.runtime.loggers import DLLogger +from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection from se3_transformer.runtime.utils import to_cuda, get_local_rank @@ -87,7 +87,10 @@ if __name__ == '__main__': major_cc, minor_cc = torch.cuda.get_device_capability() - logger = DLLogger(args.log_dir, filename=args.dllogger_name) + loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)] + if args.wandb: + loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')) + logger = LoggerCollection(loggers) datamodule = QM9DataModule(**vars(args)) model = SE3TransformerPooled( fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), @@ -108,6 +111,7 @@ if __name__ == '__main__': nproc_per_node = torch.cuda.device_count() affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node) model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) + model._set_static_graph() test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader() evaluate(model, diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py index ebb4f501..4472ed24 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py @@ -1,240 +1,241 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import logging -import pathlib -from typing import List - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from apex.optimizers import FusedAdam, FusedLAMB -from torch.nn.modules.loss import _Loss -from torch.nn.parallel import DistributedDataParallel -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler -from tqdm import tqdm - -from se3_transformer.data_loading import QM9DataModule -from se3_transformer.model import SE3TransformerPooled -from se3_transformer.model.fiber import Fiber -from se3_transformer.runtime import gpu_affinity -from se3_transformer.runtime.arguments import PARSER -from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \ - PerformanceCallback -from se3_transformer.runtime.inference import evaluate -from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger -from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \ - using_tensor_cores, increase_l2_fetch_granularity - - -def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]): - """ Saves model, optimizer and epoch states to path (only once per node) """ - if get_local_rank() == 0: - state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() - checkpoint = { - 'state_dict': state_dict, - 'optimizer_state_dict': optimizer.state_dict(), - 'epoch': epoch - } - for callback in callbacks: - callback.on_checkpoint_save(checkpoint) - - torch.save(checkpoint, str(path)) - logging.info(f'Saved checkpoint to {str(path)}') - - -def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]): - """ Loads model, optimizer and epoch states from path """ - checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'}) - if isinstance(model, DistributedDataParallel): - model.module.load_state_dict(checkpoint['state_dict']) - else: - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - for callback in callbacks: - callback.on_checkpoint_load(checkpoint) - - logging.info(f'Loaded checkpoint from {str(path)}') - return checkpoint['epoch'] - - -def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args): - losses = [] - for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch', - desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)): - *inputs, target = to_cuda(batch) - - for callback in callbacks: - callback.on_batch_start() - - with torch.cuda.amp.autocast(enabled=args.amp): - pred = model(*inputs) - loss = loss_fn(pred, target) / args.accumulate_grad_batches - - grad_scaler.scale(loss).backward() - - # gradient accumulation - if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader): - if args.gradient_clip: - grad_scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip) - - grad_scaler.step(optimizer) - grad_scaler.update() - model.zero_grad(set_to_none=True) - - losses.append(loss.item()) - - return np.mean(losses) - - -def train(model: nn.Module, - loss_fn: _Loss, - train_dataloader: DataLoader, - val_dataloader: DataLoader, - callbacks: List[BaseCallback], - logger: Logger, - args): - device = torch.cuda.current_device() - model.to(device=device) - local_rank = get_local_rank() - world_size = dist.get_world_size() if dist.is_initialized() else 1 - - if dist.is_initialized(): - model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) - - model.train() - grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) - if args.optimizer == 'adam': - optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), - weight_decay=args.weight_decay) - elif args.optimizer == 'lamb': - optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), - weight_decay=args.weight_decay) - else: - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, - weight_decay=args.weight_decay) - - epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0 - - for callback in callbacks: - callback.on_fit_start(optimizer, args) - - for epoch_idx in range(epoch_start, args.epochs): - if isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch_idx) - - loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args) - if dist.is_initialized(): - loss = torch.tensor(loss, dtype=torch.float, device=device) - torch.distributed.all_reduce(loss) - loss = (loss / world_size).item() - - logging.info(f'Train loss: {loss}') - logger.log_metrics({'train loss': loss}, epoch_idx) - - for callback in callbacks: - callback.on_epoch_end() - - if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \ - and (epoch_idx + 1) % args.ckpt_interval == 0: - save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks) - - if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs): - evaluate(model, val_dataloader, callbacks, args) - model.train() - - for callback in callbacks: - callback.on_validation_end(epoch_idx) - - if args.save_ckpt_path is not None and not args.benchmark: - save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks) - - for callback in callbacks: - callback.on_fit_end() - - -def print_parameters_count(model): - num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) - logging.info(f'Number of trainable parameters: {num_params_trainable}') - - -if __name__ == '__main__': - is_distributed = init_distributed() - local_rank = get_local_rank() - args = PARSER.parse_args() - - logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO) - - logging.info('====== SE(3)-Transformer ======') - logging.info('| Training procedure |') - logging.info('===============================') - - if args.seed is not None: - logging.info(f'Using seed {args.seed}') - seed_everything(args.seed) - - logger = LoggerCollection([ - DLLogger(save_dir=args.log_dir, filename=args.dllogger_name), - WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer') - ]) - - datamodule = QM9DataModule(**vars(args)) - model = SE3TransformerPooled( - fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), - fiber_out=Fiber({0: args.num_degrees * args.num_channels}), - fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}), - output_dim=1, - tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively - **vars(args) - ) - loss_fn = nn.L1Loss() - - if args.benchmark: - logging.info('Running benchmark mode') - world_size = dist.get_world_size() if dist.is_initialized() else 1 - callbacks = [PerformanceCallback(logger, args.batch_size * world_size)] - else: - callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'), - QM9LRSchedulerCallback(logger, epochs=args.epochs)] - - if is_distributed: - gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count()) - - print_parameters_count(model) - logger.log_hyperparams(vars(args)) - increase_l2_fetch_granularity() - train(model, - loss_fn, - datamodule.train_dataloader(), - datamodule.val_dataloader(), - callbacks, - logger, - args) - - logging.info('Training finished successfully') - - +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import logging +import pathlib +from typing import List + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from apex.optimizers import FusedAdam, FusedLAMB +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm + +from se3_transformer.data_loading import QM9DataModule +from se3_transformer.model import SE3TransformerPooled +from se3_transformer.model.fiber import Fiber +from se3_transformer.runtime import gpu_affinity +from se3_transformer.runtime.arguments import PARSER +from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \ + PerformanceCallback +from se3_transformer.runtime.inference import evaluate +from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger +from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \ + using_tensor_cores, increase_l2_fetch_granularity + + +def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]): + """ Saves model, optimizer and epoch states to path (only once per node) """ + if get_local_rank() == 0: + state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() + checkpoint = { + 'state_dict': state_dict, + 'optimizer_state_dict': optimizer.state_dict(), + 'epoch': epoch + } + for callback in callbacks: + callback.on_checkpoint_save(checkpoint) + + torch.save(checkpoint, str(path)) + logging.info(f'Saved checkpoint to {str(path)}') + + +def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]): + """ Loads model, optimizer and epoch states from path """ + checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'}) + if isinstance(model, DistributedDataParallel): + model.module.load_state_dict(checkpoint['state_dict']) + else: + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + for callback in callbacks: + callback.on_checkpoint_load(checkpoint) + + logging.info(f'Loaded checkpoint from {str(path)}') + return checkpoint['epoch'] + + +def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args): + losses = [] + for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch', + desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)): + *inputs, target = to_cuda(batch) + + for callback in callbacks: + callback.on_batch_start() + + with torch.cuda.amp.autocast(enabled=args.amp): + pred = model(*inputs) + loss = loss_fn(pred, target) / args.accumulate_grad_batches + + grad_scaler.scale(loss).backward() + + # gradient accumulation + if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader): + if args.gradient_clip: + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip) + + grad_scaler.step(optimizer) + grad_scaler.update() + model.zero_grad(set_to_none=True) + + losses.append(loss.item()) + + return np.mean(losses) + + +def train(model: nn.Module, + loss_fn: _Loss, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + callbacks: List[BaseCallback], + logger: Logger, + args): + device = torch.cuda.current_device() + model.to(device=device) + local_rank = get_local_rank() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if dist.is_initialized(): + model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) + model._set_static_graph() + + model.train() + grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) + if args.optimizer == 'adam': + optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), + weight_decay=args.weight_decay) + elif args.optimizer == 'lamb': + optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), + weight_decay=args.weight_decay) + else: + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, + weight_decay=args.weight_decay) + + epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0 + + for callback in callbacks: + callback.on_fit_start(optimizer, args) + + for epoch_idx in range(epoch_start, args.epochs): + if isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch_idx) + + loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, + args) + if dist.is_initialized(): + loss = torch.tensor(loss, dtype=torch.float, device=device) + torch.distributed.all_reduce(loss) + loss = (loss / world_size).item() + + logging.info(f'Train loss: {loss}') + logger.log_metrics({'train loss': loss}, epoch_idx) + + for callback in callbacks: + callback.on_epoch_end() + + if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \ + and (epoch_idx + 1) % args.ckpt_interval == 0: + save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks) + + if not args.benchmark and ( + (args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs): + evaluate(model, val_dataloader, callbacks, args) + model.train() + + for callback in callbacks: + callback.on_validation_end(epoch_idx) + + if args.save_ckpt_path is not None and not args.benchmark: + save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks) + + for callback in callbacks: + callback.on_fit_end() + + +def print_parameters_count(model): + num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + logging.info(f'Number of trainable parameters: {num_params_trainable}') + + +if __name__ == '__main__': + is_distributed = init_distributed() + local_rank = get_local_rank() + args = PARSER.parse_args() + + logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO) + + logging.info('====== SE(3)-Transformer ======') + logging.info('| Training procedure |') + logging.info('===============================') + + if args.seed is not None: + logging.info(f'Using seed {args.seed}') + seed_everything(args.seed) + + loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)] + if args.wandb: + loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')) + logger = LoggerCollection(loggers) + + datamodule = QM9DataModule(**vars(args)) + model = SE3TransformerPooled( + fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), + fiber_out=Fiber({0: args.num_degrees * args.num_channels}), + fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}), + output_dim=1, + tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively + **vars(args) + ) + loss_fn = nn.L1Loss() + + if args.benchmark: + logging.info('Running benchmark mode') + world_size = dist.get_world_size() if dist.is_initialized() else 1 + callbacks = [PerformanceCallback(logger, args.batch_size * world_size)] + else: + callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'), + QM9LRSchedulerCallback(logger, epochs=args.epochs)] + + if is_distributed: + gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count()) + + print_parameters_count(model) + logger.log_hyperparams(vars(args)) + increase_l2_fetch_granularity() + train(model, + loss_fn, + datamodule.train_dataloader(), + datamodule.val_dataloader(), + callbacks, + logger, + args) + + logging.info('Training finished successfully') diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/setup.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/setup.py index 82714897..dc601a41 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/setup.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/setup.py @@ -2,9 +2,9 @@ from setuptools import setup, find_packages setup( name='se3-transformer', - packages=find_packages(), + packages=find_packages(exclude=['tests']), include_package_data=True, - version='1.0.0', + version='1.1.0', description='PyTorch + DGL implementation of SE(3)-Transformers', author='Alexandre Milesi', author_email='alexandrem@nvidia.com', diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py index f19e4d94..7fede4a6 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/test_equivariance.py @@ -1,102 +1,106 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import torch - -from se3_transformer.model import SE3Transformer -from se3_transformer.model.fiber import Fiber -from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot - -# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) -TOL = 1e-3 -CHANNELS, NODES = 32, 512 - - -def _get_outputs(model, R): - feats0 = torch.randn(NODES, CHANNELS, 1) - feats1 = torch.randn(NODES, CHANNELS, 3) - - coords = torch.randn(NODES, 3) - graph = get_random_graph(NODES) - if torch.cuda.is_available(): - feats0 = feats0.cuda() - feats1 = feats1.cuda() - R = R.cuda() - coords = coords.cuda() - graph = graph.to('cuda') - model.cuda() - - graph1 = assign_relative_pos(graph, coords) - out1 = model(graph1, {'0': feats0, '1': feats1}, {}) - graph2 = assign_relative_pos(graph, coords @ R) - out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {}) - - return out1, out2 - - -def _get_model(**kwargs): - return SE3Transformer( - num_layers=4, - fiber_in=Fiber.create(2, CHANNELS), - fiber_hidden=Fiber.create(3, CHANNELS), - fiber_out=Fiber.create(2, CHANNELS), - fiber_edge=Fiber({}), - num_heads=8, - channels_div=2, - **kwargs - ) - - -def test_equivariance(): - model = _get_model() - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ - f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' - assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ - f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}' - - -def test_equivariance_pooled(): - model = _get_model(pooling='avg', return_type=1) - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2, (out1 @ R), atol=TOL), \ - f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' - - -def test_invariance_pooled(): - model = _get_model(pooling='avg', return_type=0) - R = rot(*torch.rand(3)) - if torch.cuda.is_available(): - R = R.cuda() - out1, out2 = _get_outputs(model, R) - - assert torch.allclose(out2, out1, atol=TOL), \ - f'type-0 features should be invariant {get_max_diff(out1, out2)}' +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import torch + +from se3_transformer.model import SE3Transformer +from se3_transformer.model.fiber import Fiber + +if __package__ is None or __package__ == '': + from utils import get_random_graph, assign_relative_pos, get_max_diff, rot +else: + from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot + +# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) +TOL = 1e-3 +CHANNELS, NODES = 32, 512 + + +def _get_outputs(model, R): + feats0 = torch.randn(NODES, CHANNELS, 1) + feats1 = torch.randn(NODES, CHANNELS, 3) + + coords = torch.randn(NODES, 3) + graph = get_random_graph(NODES) + if torch.cuda.is_available(): + feats0 = feats0.cuda() + feats1 = feats1.cuda() + R = R.cuda() + coords = coords.cuda() + graph = graph.to('cuda') + model.cuda() + + graph1 = assign_relative_pos(graph, coords) + out1 = model(graph1, {'0': feats0, '1': feats1}, {}) + graph2 = assign_relative_pos(graph, coords @ R) + out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {}) + + return out1, out2 + + +def _get_model(**kwargs): + return SE3Transformer( + num_layers=4, + fiber_in=Fiber.create(2, CHANNELS), + fiber_hidden=Fiber.create(3, CHANNELS), + fiber_out=Fiber.create(2, CHANNELS), + fiber_edge=Fiber({}), + num_heads=8, + channels_div=2, + **kwargs + ) + + +def test_equivariance(): + model = _get_model() + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ + f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' + assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ + f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}' + + +def test_equivariance_pooled(): + model = _get_model(pooling='avg', return_type=1) + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2, (out1 @ R), atol=TOL), \ + f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' + + +def test_invariance_pooled(): + model = _get_model(pooling='avg', return_type=0) + R = rot(*torch.rand(3)) + if torch.cuda.is_available(): + R = R.cuda() + out1, out2 = _get_outputs(model, R) + + assert torch.allclose(out2, out1, atol=TOL), \ + f'type-0 features should be invariant {get_max_diff(out1, out2)}' diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py index d72bebc3..195f0aef 100644 --- a/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py +++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/tests/utils.py @@ -1,60 +1,60 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the "Software"), -# to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, -# and/or sell copies of the Software, and to permit persons to whom the -# Software is furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -# -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES -# SPDX-License-Identifier: MIT - -import dgl -import torch - - -def get_random_graph(N, num_edges_factor=18): - graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) - return graph - - -def assign_relative_pos(graph, coords): - src, dst = graph.edges() - graph.edata['rel_pos'] = coords[src] - coords[dst] - return graph - - -def get_max_diff(a, b): - return (a - b).abs().max().item() - - -def rot_z(gamma): - return torch.tensor([ - [torch.cos(gamma), -torch.sin(gamma), 0], - [torch.sin(gamma), torch.cos(gamma), 0], - [0, 0, 1] - ], dtype=gamma.dtype) - - -def rot_y(beta): - return torch.tensor([ - [torch.cos(beta), 0, torch.sin(beta)], - [0, 1, 0], - [-torch.sin(beta), 0, torch.cos(beta)] - ], dtype=beta.dtype) - - -def rot(alpha, beta, gamma): - return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +# +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES +# SPDX-License-Identifier: MIT + +import dgl +import torch + + +def get_random_graph(N, num_edges_factor=18): + graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) + return graph + + +def assign_relative_pos(graph, coords): + src, dst = graph.edges() + graph.edata['rel_pos'] = coords[src] - coords[dst] + return graph + + +def get_max_diff(a, b): + return (a - b).abs().max().item() + + +def rot_z(gamma): + return torch.tensor([ + [torch.cos(gamma), -torch.sin(gamma), 0], + [torch.sin(gamma), torch.cos(gamma), 0], + [0, 0, 1] + ], dtype=gamma.dtype) + + +def rot_y(beta): + return torch.tensor([ + [torch.cos(beta), 0, torch.sin(beta)], + [0, 1, 0], + [-torch.sin(beta), 0, torch.cos(beta)] + ], dtype=beta.dtype) + + +def rot(alpha, beta, gamma): + return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) diff --git a/README.md b/README.md index bacaa3ab..85a96ddf 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided ## Graph Neural Networks | Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- | -| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - | +| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - | ## NVIDIA support