import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from contextlib import nullcontext

from typing import Dict

from equivariant_attention.from_se3cnn import utils_steerable
from equivariant_attention.fibers import Fiber, fiber2head
from utils.utils_logging import log_gradient_norm

import dgl
import dgl.function as fn
from dgl.nn.pytorch.softmax import edge_softmax
from dgl.nn.pytorch.glob import AvgPooling, MaxPooling

from packaging import version


def get_basis(G, max_degree, compute_gradients):
    """Precompute the SE(3)-equivariant weight basis, W_J^lk(x)

    This is called by get_basis_and_r().

    Args:
        G: DGL graph instance of type dgl.DGLGraph
        max_degree: non-negative int for degree of highest feature type
        compute_gradients: boolean, whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases. Keys are in the form 'd_in,d_out'. Values are
        tensors of shape (num_edges, 1, 2*d_out+1, 1, 2*d_in+1, number_of_bases),
        where the 1's will later be broadcast to the number of output and input
        channels.
    """
    if compute_gradients:
        context = nullcontext()
    else:
        context = torch.no_grad()

    with context:
        cloned_d = torch.clone(G.edata['d'])

        if G.edata['d'].requires_grad:
            cloned_d.requires_grad_()
            log_gradient_norm(cloned_d, 'Basis computation flow')

        # Relative positional encodings (vector)
        r_ij = utils_steerable.get_spherical_from_cartesian_torch(cloned_d)
        # Spherical harmonic basis
        Y = utils_steerable.precompute_sh(r_ij, 2*max_degree)
        device = Y[0].device

        basis = {}
        for d_in in range(max_degree+1):
            for d_out in range(max_degree+1):
                K_Js = []
                for J in range(abs(d_in-d_out), d_in+d_out+1):
                    # Get spherical harmonic projection matrices
                    Q_J = utils_steerable._basis_transformation_Q_J(J, d_in, d_out)
                    Q_J = Q_J.float().T.to(device)

                    # Create kernel from spherical harmonics
                    K_J = torch.matmul(Y[J], Q_J)
                    K_Js.append(K_J)

                # Reshape so that we can take linear combinations with a dot product
                size = (-1, 1, 2*d_out+1, 1, 2*d_in+1, 2*min(d_in, d_out)+1)
                basis[f'{d_in},{d_out}'] = torch.stack(K_Js, -1).view(*size)
        return basis


def get_r(G):
    """Compute internodal distances"""
    cloned_d = torch.clone(G.edata['d'])

    if G.edata['d'].requires_grad:
        cloned_d.requires_grad_()
        log_gradient_norm(cloned_d, 'Neural networks flow')

    return torch.sqrt(torch.sum(cloned_d**2, -1, keepdim=True))


def get_basis_and_r(G, max_degree, compute_gradients=False):
    """Return equivariant weight basis (basis) and internodal distances (r).

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        G: DGL graph instance of type dgl.DGLGraph()
        max_degree: non-negative int for degree of highest feature-type
        compute_gradients: controls whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases, keys are in form '<d_in>,<d_out>'
        vector of relative distances, ordered according to edge ordering of G
    """
    basis = get_basis(G, max_degree, compute_gradients)
    r = get_r(G)
    return basis, r

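
# A minimal usage sketch (an illustration, not part of the original API):
# assuming `G` is a dgl.DGLGraph whose edge field 'd' holds relative positions
# between endpoints, the basis and distances are computed once per forward pass
# and then shared by every equivariant layer defined below, e.g.:
#
#   basis, r = get_basis_and_r(G, max_degree=2)
#   # basis['1,1']: [num_edges, 1, 3, 1, 3, 3]; r: [num_edges, 1]
#   new_features = layer(features, G=G, r=r, basis=basis)
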
### SE(3) equivariant operations on graphs in DGL

class GConvSE3(nn.Module):
    """A tensor field network layer as a DGL module.

    GConvSE3 stands for a Graph Convolution SE(3)-equivariant layer. It is the
    equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
    conv layer in a GCN.

    At each node, the activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ...
    """
    def __init__(self, f_in, f_out, self_interaction: bool=False, edge_dim: int=0, flavor='skip'):
        """SE(3)-equivariant Graph Conv Layer

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
            self_interaction: include self-interaction in convolution
            edge_dim: number of dimensions for edge embedding
            flavor: allows ['TFN', 'skip'], where 'skip' adds a skip connection
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction
        self.flavor = flavor

        # Neighbor -> center weights
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

        # Center -> center weights
        self.kernel_self = nn.ParameterDict()
        if self_interaction:
            assert self.flavor in ['TFN', 'skip']
            if self.flavor == 'TFN':
                for m_out, d_out in self.f_out.structure:
                    W = nn.Parameter(torch.randn(1, m_out, m_out) / np.sqrt(m_out))
                    self.kernel_self[f'{d_out}'] = W
            elif self.flavor == 'skip':
                for m_in, d_in in self.f_in.structure:
                    if d_in in self.f_out.degrees:
                        m_out = self.f_out.structure_dict[d_in]
                        W = nn.Parameter(torch.randn(1, m_out, m_in) / np.sqrt(m_in))
                        self.kernel_self[f'{d_in}'] = W

    def __repr__(self):
        return f'GConvSE3(structure={self.f_out}, self_interaction={self.self_interaction})'

    def udf_u_mul_e(self, d_out):
        """Compute the convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            edge -> node function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            msg = 0
            for m_in, d_in in self.f_in.structure:
                src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, 2*d_out+1)

            # Center -> center messages
            if self.self_interaction:
                if f'{d_out}' in self.kernel_self.keys():
                    if self.flavor == 'TFN':
                        W = self.kernel_self[f'{d_out}']
                        msg = torch.matmul(W, msg)
                    if self.flavor == 'skip':
                        dst = edges.dst[f'{d_out}']
                        W = self.kernel_self[f'{d_out}']
                        msg = msg + torch.matmul(W, dst)

            return {'msg': msg.view(msg.shape[0], -1, 2*d_out+1)}
        return fnc

    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the convolution layer

        Args:
            h: dict of node features
            G: minibatch of (homo)graphs
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            dict of new node features, one tensor of shape
            [num_nodes, m_out, 2*d+1] per output feature type d
        """
        with G.local_scope():
            # Add node features to local graph scope
            for k, v in h.items():
                G.ndata[k] = v

            # Add edge features
            if 'w' in G.edata.keys():
                w = G.edata['w']
                feat = torch.cat([w, r], -1)
            else:
                feat = torch.cat([r, ], -1)

            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)

            # Perform message-passing for each output feature type
            for d in self.f_out.degrees:
                G.update_all(self.udf_u_mul_e(d), fn.mean('msg', f'out{d}'))

            return {f'{d}': G.ndata[f'out{d}'] for d in self.f_out.degrees}

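
# A hedged construction sketch (fibers and edge_dim are illustrative): a layer
# mapping 8 scalar (type-0) channels to 8 scalar plus 4 vector (type-1)
# channels. edge_dim=4 assumes G.edata['w'] carries 4-dimensional edge features.
#
#   conv = GConvSE3(Fiber(dictionary={0: 8}),
#                   Fiber(dictionary={0: 8, 1: 4}),
#                   self_interaction=True, edge_dim=4)
#   h = conv({'0': h0}, G=G, r=r, basis=basis)   # h0: [num_nodes, 8, 1]
#   # h['0']: [num_nodes, 8, 1], h['1']: [num_nodes, 4, 3]
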
class RadialFunc(nn.Module):
    """NN parameterized radial profile function."""
    def __init__(self, num_freq, in_dim, out_dim, edge_dim: int=0):
        """NN parameterized radial profile function.

        Args:
            num_freq: number of output frequencies
            in_dim: multiplicity of input (num input channels)
            out_dim: multiplicity of output (num output channels)
            edge_dim: number of dimensions for edge embedding
        """
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = 32
        self.out_dim = out_dim
        self.edge_dim = edge_dim

        self.net = nn.Sequential(nn.Linear(self.edge_dim + 1, self.mid_dim),
                                 BN(self.mid_dim),
                                 nn.ReLU(),
                                 nn.Linear(self.mid_dim, self.mid_dim),
                                 BN(self.mid_dim),
                                 nn.ReLU(),
                                 nn.Linear(self.mid_dim, self.num_freq * in_dim * out_dim))

        nn.init.kaiming_uniform_(self.net[0].weight)
        nn.init.kaiming_uniform_(self.net[3].weight)
        nn.init.kaiming_uniform_(self.net[6].weight)

    def __repr__(self):
        return f"RadialFunc(edge_dim={self.edge_dim}, in_dim={self.in_dim}, out_dim={self.out_dim})"

    def forward(self, x):
        y = self.net(x)
        return y.view(-1, self.out_dim, 1, self.in_dim, 1, self.num_freq)

class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features"""
    def __init__(self, degree_in: int, nc_in: int, degree_out: int,
                 nc_out: int, edge_dim: int=0):
        """SE(3)-equivariant convolution between a pair of feature types.

        This layer performs a convolution from nc_in features of type degree_in
        to nc_out features of type degree_out.

        Args:
            degree_in: degree of input fiber
            nc_in: number of channels on input
            degree_out: degree of output fiber
            nc_out: number of channels on output
            edge_dim: number of dimensions for edge embedding
        """
        super().__init__()
        # Log settings
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        # Functions of the degree
        self.num_freq = 2*min(degree_in, degree_out) + 1
        self.d_out = 2*degree_out + 1
        self.edge_dim = edge_dim

        # Radial profile function
        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, self.edge_dim)

    def forward(self, feat, basis):
        # Get radial weights
        R = self.rp(feat)
        kernel = torch.sum(R * basis[f'{self.degree_in},{self.degree_out}'], -1)
        return kernel.view(kernel.shape[0], self.d_out*self.nc_out, -1)

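
# A hedged shape sketch (values illustrative): with degree_in=1, degree_out=1,
# nc_in=4, nc_out=8, the radial net emits 2*min(1,1)+1 = 3 frequencies per
# channel pair, and forward() contracts them against basis['1,1'] into a
# per-edge kernel of shape [num_edges, (2*1+1)*8, (2*1+1)*4] = [num_edges, 24, 12]:
#
#   pconv = PairwiseConv(degree_in=1, nc_in=4, degree_out=1, nc_out=8)
#   kernel = pconv(feat, basis)   # feat: [num_edges, 1] radial (+ edge) features
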
class G1x1SE3(nn.Module):
    """Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.

    This is equivalent to a self-interaction layer in TensorField Networks.
    """
    def __init__(self, f_in, f_out, learnable=True):
        """SE(3)-equivariant 1x1 convolution.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out

        # Linear mappings: 1 per output feature type
        self.transform = nn.ParameterDict()
        for m_out, d_out in self.f_out.structure:
            m_in = self.f_in.structure_dict[d_out]
            self.transform[str(d_out)] = nn.Parameter(
                torch.randn(m_out, m_in) / np.sqrt(m_in), requires_grad=learnable)

    def __repr__(self):
        return f"G1x1SE3(structure={self.f_out})"

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            if str(k) in self.transform.keys():
                output[k] = torch.matmul(self.transform[str(k)], v)
        return output

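
# A hedged usage sketch (fibers illustrative): a 1x1 layer mixes channels within
# each degree and never across degrees, so it is equivariant by construction:
#
#   lin = G1x1SE3(Fiber(dictionary={0: 16, 1: 8}), Fiber(dictionary={0: 32, 1: 8}))
#   out = lin({'0': h0, '1': h1})   # h0: [N, 16, 1] -> out['0']: [N, 32, 1]
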
class GNormBias(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity with only learned biases."""

    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
                 num_layers: int = 0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

        # Regularization for computing phase: gradients explode otherwise
        self.eps = 1e-12

        # Norm mappings: 1 per feature type
        self.bias = nn.ParameterDict()
        for m, d in self.fiber.structure:
            self.bias[str(d)] = nn.Parameter(torch.randn(m).view(1, m))

    def __repr__(self):
        return f"GNormBias()"

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            # Compute the norms and normalized features
            # v shape: [..., m, 2*k+1]
            norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps).expand_as(v)
            phase = v / norm

            # Transform on norms: learned bias followed by the nonlinearity
            transformed = self.nonlin(norm[..., 0] + self.bias[str(k)])

            # Rescale the phase by the transformed norm
            output[k] = (transformed.unsqueeze(-1) * phase).view(*v.shape)

        return output

class GAttentiveSelfInt(nn.Module):

    def __init__(self, f_in, f_out):
        """SE(3)-equivariant attentive self-interaction layer.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.nonlin = nn.LeakyReLU()
        self.num_layers = 2
        self.eps = 1e-12  # regularisation for phase: gradients explode otherwise

        # one network for attention weights per degree
        self.transform = nn.ModuleDict()
        for o, m_in in self.f_in.structure_dict.items():
            m_out = self.f_out.structure_dict[o]
            self.transform[str(o)] = self._build_net(m_in, m_out)

    def __repr__(self):
        return f"AttentiveSelfInteractionSE3(in={self.f_in}, out={self.f_out})"

    def _build_net(self, m_in: int, m_out):
        n_hidden = m_in * m_out
        cur_inpt = m_in * m_in
        net = []
        for i in range(1, self.num_layers):
            net.append(nn.LayerNorm(int(cur_inpt)))
            net.append(self.nonlin)
            # TODO: implement cleaner init
            net.append(
                nn.Linear(cur_inpt, n_hidden, bias=(i == self.num_layers - 1)))
            nn.init.kaiming_uniform_(net[-1].weight)
            cur_inpt = n_hidden
        return nn.Sequential(*net)

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            # v shape: [..., m, 2*k+1]
            first_dims = v.shape[:-2]
            m_in = self.f_in.structure_dict[int(k)]
            m_out = self.f_out.structure_dict[int(k)]
            assert v.shape[-2] == m_in
            assert v.shape[-1] == 2 * int(k) + 1

            # Pairwise inner products between channels are SE(3)-invariant
            # scalars, and hence safe inputs to an arbitrary network
            scalars = torch.einsum('...ac,...bc->...ab', [v, v])  # [..., m_in, m_in]
            scalars = scalars.view(*first_dims, m_in * m_in)  # [..., m_in*m_in]
            sign = scalars.sign()
            scalars = scalars.abs_().clamp_min(self.eps)
            scalars *= sign

            # Perform attention over the input channels
            att_weights = self.transform[str(k)](scalars)  # [..., m_out*m_in]
            att_weights = att_weights.view(*first_dims, m_out, m_in)  # [..., m_out, m_in]
            att_weights = F.softmax(input=att_weights, dim=-1)
            # Attention-weighted combination, shape [..., m_out, 2*k+1]
            output[k] = torch.einsum('...nm,...md->...nd', [att_weights, v])

        return output

class GNormSE3(nn.Module):
    """Graph Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
    """
    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True), num_layers: int=0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

        # Regularization for computing phase: gradients explode otherwise
        self.eps = 1e-12

        # Norm mappings: 1 per feature type
        self.transform = nn.ModuleDict()
        for m, d in self.fiber.structure:
            self.transform[str(d)] = self._build_net(int(m))

    def __repr__(self):
        return f"GNormSE3(num_layers={self.num_layers}, nonlin={self.nonlin})"

    def _build_net(self, m: int):
        net = []
        for i in range(self.num_layers):
            net.append(BN(int(m)))
            net.append(self.nonlin)
            # TODO: implement cleaner init
            net.append(nn.Linear(m, m, bias=(i == self.num_layers - 1)))
            nn.init.kaiming_uniform_(net[-1].weight)
        if self.num_layers == 0:
            net.append(BN(int(m)))
            net.append(self.nonlin)
        return nn.Sequential(*net)

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            # Compute the norms and normalized features
            # v shape: [..., m, 2*k+1]
            norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps).expand_as(v)
            phase = v / norm

            # Transform on norms
            transformed = self.transform[str(k)](norm[..., 0]).unsqueeze(-1)

            # Nonlinearity on norm
            output[k] = (transformed * phase).view(*v.shape)

        return output

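
# A hedged behaviour sketch (fiber illustrative): GNormSE3 transforms only the
# invariant norms and rescales the unit "phase", so outputs rotate exactly like
# inputs:
#
#   nonlin = GNormSE3(Fiber(dictionary={0: 16, 1: 4}), num_layers=1)
#   out = nonlin({'0': h0, '1': h1})   # shapes preserved, directions preserved
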
class BN(nn.Module):
    """SE(3)-equivariant batch/layer normalization"""
    def __init__(self, m):
        """SE(3)-equivariant batch/layer normalization

        Args:
            m: int for number of output channels
        """
        super().__init__()
        self.bn = nn.LayerNorm(m)

    def forward(self, x):
        return self.bn(x)

class GConvSE3Partial(nn.Module):
    """Graph SE(3)-equivariant node -> edge layer"""
    def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
        """SE(3)-equivariant partial convolution.

        A partial convolution computes the inner product between a kernel and
        each input channel, without summing over the result from each input
        channel. This unfolded structure makes it amenable to be used for
        computing the value-embeddings of the attention mechanism.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
        """
        super().__init__()
        self.f_out = f_out
        self.edge_dim = edge_dim

        # adding/concatenating relative position to feature vectors
        # 'cat' concatenates relative position & existing feature vector
        # 'add' adds it, but only if multiplicity > 1
        assert x_ij in [None, 'cat', 'add']
        self.x_ij = x_ij
        if x_ij == 'cat':
            self.f_in = Fiber.combine(f_in, Fiber(structure=[(1, 1)]))
        else:
            self.f_in = f_in

        # Node -> edge weights
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

    def __repr__(self):
        return f'GConvSE3Partial(structure={self.f_out})'

    def udf_u_mul_e(self, d_out):
        """Compute the partial convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            node -> edge function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            msg = 0
            for m_in, d_in in self.f_in.structure:
                # if type 1 and flag set, add relative position as feature
                if self.x_ij == 'cat' and d_in == 1:
                    # relative positions
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    m_ori = m_in - 1
                    if m_ori == 0:
                        # no type 1 input feature, just use relative position
                        src = rel
                    else:
                        # features of src node, shape [edges, m_ori*(2*d_in+1), 1]
                        src = edges.src[f'{d_in}'].view(-1, m_ori*(2*d_in+1), 1)
                        # append relative position to feature vector
                        src = torch.cat([src, rel], dim=1)
                elif self.x_ij == 'add' and d_in == 1 and m_in > 1:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    src[..., :3, :1] = src[..., :3, :1] + rel
                else:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, 2*d_out+1)

            return {f'out{d_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
        return fnc

    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the partial convolution layer

        Args:
            h: dict of node-features
            G: minibatch of (homo)graphs
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            dict of new edge features, one tensor per output feature type
        """
        with G.local_scope():
            # Add node features to local graph scope
            for k, v in h.items():
                G.ndata[k] = v

            # Add edge features
            if 'w' in G.edata.keys():
                w = G.edata['w']  # shape: [#edges_in_batch, #bond_types]
                feat = torch.cat([w, r], -1)
            else:
                feat = torch.cat([r, ], -1)
            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)

            # Perform message-passing for each output feature type
            for d in self.f_out.degrees:
                G.apply_edges(self.udf_u_mul_e(d))

            return {f'{d}': G.edata[f'out{d}'] for d in self.f_out.degrees}

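
# A hedged note on x_ij (illustrative; assumes node positions are stored in
# G.ndata['x']): with x_ij='cat', one type-1 channel carrying the relative
# position of each edge's endpoints is appended to the input fiber, so
# f_in={0: 16} is treated as {0: 16, 1: 1} internally:
#
#   vconv = GConvSE3Partial(Fiber(dictionary={0: 16}),
#                           Fiber(dictionary={0: 4, 1: 2}), x_ij='cat')
#   values = vconv(h, G=G, r=r, basis=basis)   # per-edge value embeddings
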
class GMABSE3(nn.Module):
    """An SE(3)-equivariant multi-headed self-attention module for DGL graphs."""
    def __init__(self, f_value: Fiber, f_key: Fiber, n_heads: int):
        """SE(3)-equivariant MAB (multi-headed attention block) layer.

        Args:
            f_value: Fiber() object for value-embeddings
            f_key: Fiber() object for key-embeddings
            n_heads: number of heads
        """
        super().__init__()
        self.f_value = f_value
        self.f_key = f_key
        self.n_heads = n_heads
        self.new_dgl = version.parse(dgl.__version__) > version.parse('0.4.4')

    def __repr__(self):
        return f'GMABSE3(n_heads={self.n_heads}, structure={self.f_value})'

    def udf_u_mul_e(self, d_out):
        """Compute the weighted sum for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            edge -> node function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            attn = edges.data['a']
            value = edges.data[f'v{d_out}']

            # Apply attention weights
            msg = attn.unsqueeze(-1).unsqueeze(-1) * value

            return {'m': msg}
        return fnc

    def forward(self, v, k: Dict=None, q: Dict=None, G=None, **kwargs):
        """Forward pass of the attention layer

        Args:
            G: minibatch of (homo)graphs
            v: dict of value edge-features
            k: dict of key edge-features
            q: dict of query node-features
        Returns:
            dict of new node features, one tensor per output feature type
        """
        with G.local_scope():
            # Add node features to local graph scope
            ## We use the stacked tensor representation for attention
            for m, d in self.f_value.structure:
                G.edata[f'v{d}'] = v[f'{d}'].view(-1, self.n_heads, m//self.n_heads, 2*d+1)
            G.edata['k'] = fiber2head(k, self.n_heads, self.f_key, squeeze=True)  # [edges, heads, channels](?)
            G.ndata['q'] = fiber2head(q, self.n_heads, self.f_key, squeeze=True)  # [nodes, heads, channels](?)

            # Compute attention weights
            ## Inner product between (key) neighborhood and (query) center
            G.apply_edges(fn.e_dot_v('k', 'q', 'e'))

            ## Apply softmax
            e = G.edata.pop('e')
            if self.new_dgl:
                # in DGL >= 0.5, e has an extra dimension compared to DGL 0.4;
                # in the following, we get rid of it by reshaping
                n_edges = G.edata['k'].shape[0]
                e = e.view([n_edges, self.n_heads])
            e = e / np.sqrt(self.f_key.n_features)
            G.edata['a'] = edge_softmax(G, e)

            # Perform attention-weighted message-passing
            for d in self.f_value.degrees:
                G.update_all(self.udf_u_mul_e(d), fn.sum('m', f'out{d}'))

            output = {}
            for m, d in self.f_value.structure:
                output[f'{d}'] = G.ndata[f'out{d}'].view(-1, m, 2*d+1)

            return output

class GSE3Res(nn.Module):
    """Graph attention block with SE(3)-equivariance and skip connection"""
    def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4,
                 n_heads: int=1, learnable_skip=True, skip='cat', selfint='1x1', x_ij=None):
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.div = div
        self.n_heads = n_heads
        self.skip = skip  # valid: 'cat', 'sum', None

        # f_mid_out has same structure as 'f_out' but #channels divided by 'div'
        # this will be used for the values
        f_mid_out = {k: int(v // div) for k, v in self.f_out.structure_dict.items()}
        self.f_mid_out = Fiber(dictionary=f_mid_out)

        # f_mid_in has same structure as f_mid_out, but only degrees which are in f_in
        # this will be used for keys and queries
        # (queries are merely projected, hence degrees have to match input)
        f_mid_in = {d: m for d, m in f_mid_out.items() if d in self.f_in.degrees}
        self.f_mid_in = Fiber(dictionary=f_mid_in)

        self.edge_dim = edge_dim

        self.GMAB = nn.ModuleDict()

        # Projections
        self.GMAB['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim, x_ij=x_ij)
        self.GMAB['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim, x_ij=x_ij)
        self.GMAB['q'] = G1x1SE3(f_in, self.f_mid_in)

        # Attention
        self.GMAB['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)

        # Skip connections
        if self.skip == 'cat':
            self.cat = GCat(self.f_mid_out, f_in)
            if selfint == 'att':
                self.project = GAttentiveSelfInt(self.cat.f_out, f_out)
            elif selfint == '1x1':
                self.project = G1x1SE3(self.cat.f_out, f_out, learnable=learnable_skip)
        elif self.skip == 'sum':
            self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
            self.add = GSum(f_out, f_in)
            # The following checks whether the skip connection would change the
            # output fibre structure. The reason can be that the input has more
            # channels than the output (for at least one degree); this would
            # then cause a (hard to debug) error in the next layer.
            assert self.add.f_out.structure_dict == f_out.structure_dict, \
                'skip connection would change output structure'

    def forward(self, features, G, **kwargs):
        # Embeddings
        v = self.GMAB['v'](features, G=G, **kwargs)
        k = self.GMAB['k'](features, G=G, **kwargs)
        q = self.GMAB['q'](features, G=G)

        # Attention
        z = self.GMAB['attn'](v, k=k, q=q, G=G)

        if self.skip == 'cat':
            z = self.cat(z, features)
            z = self.project(z)
        elif self.skip == 'sum':
            # Skip + residual
            z = self.project(z)
            z = self.add(z, features)
        return z

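
# A hedged end-to-end sketch of one attention block (fibers illustrative; G, r,
# basis as produced by get_basis_and_r above):
#
#   block = GSE3Res(Fiber(dictionary={0: 16, 1: 8}),
#                   Fiber(dictionary={0: 16, 1: 8}),
#                   div=4, n_heads=2, skip='cat')
#   h = block({'0': h0, '1': h1}, G=G, r=r, basis=basis)
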
### Helper and wrapper functions

class GSum(nn.Module):
    """SE(3)-equivariant graph residual sum function."""
    def __init__(self, f_x: Fiber, f_y: Fiber):
        """SE(3)-equivariant graph residual sum function.

        Args:
            f_x: Fiber() object for fiber of summands
            f_y: Fiber() object for fiber of summands
        """
        super().__init__()
        self.f_x = f_x
        self.f_y = f_y
        self.f_out = Fiber.combine_max(f_x, f_y)

    def __repr__(self):
        return f"GSum(structure={self.f_out})"

    def forward(self, x, y):
        out = {}
        for k in self.f_out.degrees:
            k = str(k)
            if (k in x) and (k in y):
                if x[k].shape[1] > y[k].shape[1]:
                    # Zero-pad y along the channel dimension to match x
                    diff = x[k].shape[1] - y[k].shape[1]
                    zeros = torch.zeros(x[k].shape[0], diff, x[k].shape[2]).to(y[k].device)
                    y[k] = torch.cat([y[k], zeros], 1)
                elif x[k].shape[1] < y[k].shape[1]:
                    # Zero-pad x along the channel dimension to match y
                    diff = y[k].shape[1] - x[k].shape[1]
                    zeros = torch.zeros(x[k].shape[0], diff, x[k].shape[2]).to(y[k].device)
                    x[k] = torch.cat([x[k], zeros], 1)

                out[k] = x[k] + y[k]
            elif k in x:
                out[k] = x[k]
            elif k in y:
                out[k] = y[k]
        return out

class GCat(nn.Module):
    """Concat only degrees which are in f_x"""
    def __init__(self, f_x: Fiber, f_y: Fiber):
        super().__init__()
        self.f_x = f_x
        self.f_y = f_y
        f_out = {}
        for k in f_x.degrees:
            f_out[k] = f_x.dict[k]
            if k in f_y.degrees:
                f_out[k] += f_y.dict[k]
        self.f_out = Fiber(dictionary=f_out)

    def __repr__(self):
        return f"GCat(structure={self.f_out})"

    def forward(self, x, y):
        out = {}
        for k in self.f_out.degrees:
            k = str(k)
            if k in y:
                out[k] = torch.cat([x[k], y[k]], 1)
            else:
                out[k] = x[k]
        return out

class GAvgPooling(nn.Module):
    """Graph Average Pooling module."""
    def __init__(self, type='0'):
        super().__init__()
        self.pool = AvgPooling()
        self.type = type

    def forward(self, features, G, **kwargs):
        if self.type == '0':
            h = features['0'][..., -1]
            pooled = self.pool(G, h)
        elif self.type == '1':
            # Pool each of the three vector components separately
            pooled = []
            for i in range(3):
                h_i = features['1'][..., i]
                pooled.append(self.pool(G, h_i).unsqueeze(-1))
            pooled = torch.cat(pooled, dim=-1)
            pooled = {'1': pooled}
        else:
            raise NotImplementedError(f'GAvgPooling for type={self.type} not implemented')
        return pooled

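
# A hedged readout sketch: the poolers consume the node-feature dict and return
# per-graph tensors; GAvgPooling with the default type='0' averages the
# invariant (type-0) features over each graph in the batch:
#
#   readout = GAvgPooling()
#   graph_feat = readout(h, G=G)   # [batch_size, num_type0_channels]
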
class GMaxPooling(nn.Module):
    """Graph Max Pooling module."""
    def __init__(self):
        super().__init__()
        self.pool = MaxPooling()

    def forward(self, features, G, **kwargs):
        # Max-pool over the invariant (type-0) features
        h = features['0'][..., -1]
        return self.pool(G, h)