# Source: DeepLearningExamples/DGLPyTorch/DrugDiscovery/RoseTTAFold/network/SE3_network.py
# Last modified: 2021-10-15 15:46:41 +02:00 (109 lines, 4.7 KiB, Python)

import torch
import torch.nn as nn
from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
from equivariant_attention.modules import GConvSE3, GNormSE3
from equivariant_attention.fibers import Fiber
class TFN(nn.Module):
    """SE(3)-equivariant graph convolution network built from ``GConvSE3`` layers.

    The network is ``num_layers - 1`` hidden blocks (``GConvSE3`` followed by
    ``GNormSE3``) and then one final ``GConvSE3`` that projects to the output fiber.

    Args:
        num_layers: Total number of ``GConvSE3`` layers (hidden blocks + output layer).
        num_channels: Multiplicity of every degree in the hidden fiber.
        num_nonlin_layers: Forwarded as ``num_layers`` to each ``GNormSE3``.
        num_degrees: Number of degrees in the hidden fiber; the equivariant basis
            is computed up to degree ``num_degrees - 1``.
        l0_in_features: Multiplicity of the degree-0 input features.
        l0_out_features: Multiplicity of the degree-0 output features.
        l1_in_features: Multiplicity of the degree-1 input features.
        l1_out_features: Multiplicity of the degree-1 output features. If it is
            not positive, the output fiber contains degree 0 only.
        num_edge_features: Forwarded as ``edge_dim`` to every ``GConvSE3``.
        use_self: Forwarded as ``self_interaction`` to every ``GConvSE3``.
    """

    def __init__(self, num_layers=2, num_channels=32, num_nonlin_layers=1, num_degrees=3,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=3,
                 num_edge_features=32, use_self=True):
        super().__init__()
        # Build the network
        self.num_layers = num_layers
        self.num_nlayers = num_nonlin_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.use_self = use_self

        # Only request a degree-1 output when it has a positive multiplicity.
        out_dict = {0: l0_out_features}
        if l1_out_features > 0:
            out_dict[1] = l1_out_features
        fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
                  'mid': Fiber(self.num_degrees, self.num_channels),
                  'out': Fiber(dictionary=out_dict)}
        self.block0 = self._build_gcn(fibers)

    def _build_gcn(self, fibers):
        """Return the layer stack as an ``nn.ModuleList``.

        Args:
            fibers: Dict with ``'in'``, ``'mid'`` and ``'out'`` ``Fiber`` objects.
        """
        layers = []
        fin = fibers['in']
        for _ in range(self.num_layers - 1):
            layers.append(GConvSE3(fin, fibers['mid'], self_interaction=self.use_self, edge_dim=self.edge_dim))
            layers.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers))
            fin = fibers['mid']
        # Use the running input fiber rather than fibers['mid']. With
        # num_layers == 1 the loop above is skipped, so the output layer must
        # consume the network input fiber directly.
        layers.append(GConvSE3(fin, fibers['out'], self_interaction=self.use_self, edge_dim=self.edge_dim))
        return nn.ModuleList(layers)

    # NOTE: disabling autocast does not up-cast tensors that were produced in an
    # autocast-enabled region; callers presumably pass float32 inputs — confirm.
    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, G, type_0_features, type_1_features):
        """Run every layer in ``self.block0`` over graph ``G``.

        Args:
            G: Graph passed to ``get_basis_and_r`` and to every layer
                (presumably a DGL graph carrying positions — TODO confirm).
            type_0_features: Degree-0 node features, stored under key ``'0'``.
            type_1_features: Degree-1 node features, stored under key ``'1'``.

        Returns:
            Whatever the last layer returns (presumably a dict keyed by degree
            string, like the input ``h``).
        """
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)
        h = {'0': type_0_features, '1': type_1_features}
        for layer in self.block0:
            h = layer(h, G=G, r=r, basis=basis)
        return h
class SE3Transformer(nn.Module):
    """SE(3)-equivariant graph network with attention, built from ``GSE3Res`` blocks.

    The network is ``num_layers`` hidden blocks (``GSE3Res`` followed by
    ``GNormBias``) and then one final ``GSE3Res`` that projects to the output fiber.

    Args:
        num_layers: Number of hidden ``GSE3Res`` + ``GNormBias`` blocks.
        num_channels: Multiplicity of every degree in the hidden fiber.
        num_degrees: Number of degrees in the hidden fiber; the equivariant basis
            is computed up to degree ``num_degrees - 1``.
        n_heads: Forwarded as ``n_heads`` to the hidden ``GSE3Res`` blocks.
        div: Forwarded as ``div`` to the hidden ``GSE3Res`` blocks.
        si_m: Forwarded as ``selfint`` to the hidden ``GSE3Res`` blocks.
        si_e: Forwarded as ``selfint`` to the output ``GSE3Res`` block.
        l0_in_features: Multiplicity of the degree-0 input features.
        l0_out_features: Multiplicity of the degree-0 output features.
        l1_in_features: Multiplicity of the degree-1 input features.
        l1_out_features: Multiplicity of the degree-1 output features. If it is
            not positive, the output fiber contains degree 0 only.
        num_edge_features: Forwarded as ``edge_dim`` to every ``GSE3Res``.
        x_ij: Forwarded as ``x_ij`` to every ``GSE3Res``.
    """

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 si_m='1x1', si_e='att',
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=3,
                 num_edge_features=32, x_ij=None):
        super().__init__()
        # Build the network
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.div = div
        self.n_heads = n_heads
        self.si_m, self.si_e = si_m, si_e
        self.x_ij = x_ij

        # Only request a degree-1 output when it has a positive multiplicity.
        out_dict = {0: l0_out_features}
        if l1_out_features > 0:
            out_dict[1] = l1_out_features
        fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
                  'mid': Fiber(self.num_degrees, self.num_channels),
                  'out': Fiber(dictionary=out_dict)}
        self.Gblock = self._build_gcn(fibers)

    def _build_gcn(self, fibers):
        """Return the equivariant layer stack as an ``nn.ModuleList``.

        Args:
            fibers: Dict with ``'in'``, ``'mid'`` and ``'out'`` ``Fiber`` objects.
        """
        layers = []
        fin = fibers['in']
        for _ in range(self.num_layers):
            layers.append(GSE3Res(fin, fibers['mid'], edge_dim=self.edge_dim,
                                  div=self.div, n_heads=self.n_heads,
                                  learnable_skip=True, skip='cat',
                                  selfint=self.si_m, x_ij=self.x_ij))
            layers.append(GNormBias(fibers['mid']))
            fin = fibers['mid']
        # Use the running input fiber rather than fibers['mid'], so that a
        # num_layers == 0 network maps the input fiber straight to the output.
        # The output block always uses a single head with div=1.
        layers.append(
            GSE3Res(fin, fibers['out'], edge_dim=self.edge_dim,
                    div=1, n_heads=1, learnable_skip=True,
                    skip='cat', selfint=self.si_e, x_ij=self.x_ij))
        return nn.ModuleList(layers)

    # NOTE: disabling autocast does not up-cast tensors that were produced in an
    # autocast-enabled region; callers presumably pass float32 inputs — confirm.
    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, G, type_0_features, type_1_features):
        """Run every layer in ``self.Gblock`` over graph ``G``.

        Args:
            G: Graph passed to ``get_basis_and_r`` and to every layer
                (presumably a DGL graph carrying positions — TODO confirm).
            type_0_features: Degree-0 node features, stored under key ``'0'``.
            type_1_features: Degree-1 node features, stored under key ``'1'``.

        Returns:
            Whatever the last layer returns (presumably a dict keyed by degree
            string, like the input ``h``).
        """
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)
        h = {'0': type_0_features, '1': type_1_features}
        for layer in self.Gblock:
            h = layer(h, G=G, r=r, basis=basis)
        return h