2021-10-21 09:11:30 +02:00
|
|
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
|
|
#
|
|
|
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
|
|
|
# copy of this software and associated documentation files (the "Software"),
|
|
|
|
# to deal in the Software without restriction, including without limitation
|
|
|
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
|
|
# and/or sell copies of the Software, and to permit persons to whom the
|
|
|
|
# Software is furnished to do so, subject to the following conditions:
|
|
|
|
#
|
|
|
|
# The above copyright notice and this permission notice shall be included in
|
|
|
|
# all copies or substantial portions of the Software.
|
|
|
|
#
|
|
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|
|
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
|
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
|
|
# DEALINGS IN THE SOFTWARE.
|
|
|
|
#
|
|
|
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
|
|
import logging
|
|
|
|
from typing import Optional, Literal, Dict
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from dgl import DGLGraph
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
|
|
|
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
|
|
|
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
|
|
|
from se3_transformer.model.layers.norm import NormSE3
|
|
|
|
from se3_transformer.model.layers.pooling import GPooling
|
|
|
|
from se3_transformer.runtime.utils import str2bool
|
|
|
|
from se3_transformer.model.fiber import Fiber
|
|
|
|
|
|
|
|
|
|
|
|
class Sequential(nn.Sequential):
    """nn.Sequential variant that forwards extra arguments to every child.

    Each child module is called as ``child(previous_output, *args, **kwargs)``,
    so shared context (e.g. edge features, graph, basis) can be threaded
    through the whole stack alongside the evolving features.
    """

    def forward(self, input, *args, **kwargs):
        output = input
        for layer in self:
            # The same extra args/kwargs are handed to every layer unchanged.
            output = layer(output, *args, **kwargs)
        return output
|
|
|
|
|
|
|
|
|
|
|
|
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """Return edge features with edge lengths appended to the degree-0 entry.

    The Euclidean norm of ``relative_pos`` over its last dimension is given two
    trailing singleton dims and then either concatenated along dim 1 onto
    ``edge_features['0']`` or, if that key is absent, stored as ``'0'``.
    A shallow copy of ``edge_features`` is returned; the caller's dict is not modified.
    """
    populated = {} if not edge_features else edge_features.copy()
    # keepdim=True plus unsqueeze(-1) is the same as r[..., None] on a keepdim norm
    distances = relative_pos.norm(dim=-1, keepdim=True).unsqueeze(-1)

    if '0' not in populated:
        populated['0'] = distances
    else:
        populated['0'] = torch.cat([populated['0'], distances], dim=1)

    return populated
|
|
|
|
|
|
|
|
|
|
|
|
class SE3Transformer(nn.Module):
    """SE(3)-equivariant graph Transformer.

    Stacks ``num_layers`` ``AttentionBlockSE3`` blocks (each optionally followed by
    ``NormSE3``), then a final self-interacting ``ConvSE3`` mapping the hidden fiber to
    ``fiber_out``. Optionally pools over the graph and/or selects one output degree.
    """

    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers: Number of attention layers
        :param fiber_in: Input fiber description
        :param fiber_hidden: Hidden fiber description
        :param fiber_out: Output fiber description
        :param fiber_edge: Input edge fiber description
        :param num_heads: Number of attention heads
        :param channels_div: Channels division before feeding to attention layer
        :param return_type: Return only features of this type
        :param pooling: 'avg' or 'max' graph pooling before MLP layers
        :param norm: Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory: If True, will use slower ops that use less memory
        :param kwargs: Accepted and ignored (presumably so a full argparse namespace can be passed in)
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        # Highest degree appearing in any of the input / hidden / output fibers
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions when using Tensor Cores (and not low memory mode).
        # forward() only builds the 'fully fused' bases under the same condition.
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for _ in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            # After the first block every layer consumes the hidden fiber
            fiber_in = fiber_hidden

        # Give the output convolution the same fuse level as the attention blocks, so it
        # never relies on fused bases that forward() did not compute.
        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree,
                                     fuse_level=fuse_level))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        """Run the equivariant network on ``graph``.

        :param graph: DGL graph; ``graph.edata['rel_pos']`` is read to build bases and edge features
        :param node_feats: Node features keyed by degree as a string (e.g. '0')
        :param edge_feats: Optional edge features keyed by degree; edge lengths are appended to '0'
        :param basis: Optional precomputed bases; computed here when missing or empty
        :return: the pooled features if pooling was configured; otherwise the features of
                 degree ``return_type`` if set; otherwise the full node-feature dict
        """
        # Padded / fully fused bases are only used together with fully fused convolutions
        use_fused = self.tensor_cores and not self.low_memory

        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=use_fused,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=use_fused,
                                        fully_fused=use_fused)

        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        """Register this model's command-line options on ``parser`` and return it."""
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        # NOTE(review): the CLI defaults for --norm / --use_layer_norm (False) differ from the
        # constructor defaults (True); confirm this is intended.
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser
|
|
|
|
|
|
|
|
|
|
|
|
class SE3TransformerPooled(nn.Module):
    """SE3Transformer with graph pooling of degree-0 outputs followed by a two-layer MLP head."""

    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        """
        :param fiber_in: Input fiber description
        :param fiber_out: Output fiber description of the transformer (input to the MLP head)
        :param fiber_edge: Input edge fiber description
        :param num_degrees: Number of hidden degrees (hidden fiber built with Fiber.create)
        :param num_channels: Number of channels per hidden degree
        :param output_dim: Output size of the final linear layer
        :param kwargs: Forwarded to SE3Transformer; 'pooling' defaults to 'max' when missing or None
        """
        super().__init__()
        # .get() so that constructing without a 'pooling' kwarg does not raise KeyError;
        # a missing or None value (e.g. argparse default) falls back to max pooling.
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        """Pool the transformer's degree-0 output over each graph and apply the MLP head."""
        # presumably the pooled degree-0 features end in a singleton dim — TODO confirm shape
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        """Add a 'Model architecture' argument group to ``parent_parser`` and return the parent."""
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
|