# DeepLearningExamples/DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Attention_module_w_str.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from Transformer import *
from Transformer import _get_clones
import Transformer
from resnet import ResidualNetwork
from SE3_network import SE3Transformer
from InitStrGenerator import InitStr_Network
import dgl
# Attention module based on AlphaFold2 ideas, written by Minkyung Baek
# - Iterative MSA feature extraction
# - 1) MSA2Pair: extract pairwise feature from MSA --> added to previous residue-pair features
# architecture design inspired by CopulaNet paper
# - 2) MSA2MSA: process MSA features using Transformer (or Performer) encoder. (Attention over L first followed by attention over N)
# - 3) Pair2MSA: Update MSA features using pair feature
# - 4) Pair2Pair: process pair features using Transformer (or Performer) encoder.
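# Tensor shape conventions used throughout (B: batch, N: #sequences in the MSA, L: sequence length):
#   msa:  (B, N, L, d_msa)   MSA embeddings
#   pair: (B, L, L, d_pair)  residue-pair embeddings
#   xyz:  (B, L, 3, 3)       backbone N, CA, C coordinates per residue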
def make_graph(xyz, pair, idx, top_k=64, kmin=9):
'''
Input:
    - xyz: current backbone coordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# distance map from current CA coordinates
D = torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]) + torch.eye(L, device=device).unsqueeze(0)*999.9 # (B, L, L)
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*999.9
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
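    # topk_matrix[b, i, j] = 1 if residue j is among the top_k nearest neighbors of residue i (by CA-CA distance)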
    # put an edge if either of the two conditions is met:
    # 1) |i-j| < kmin (connect sequentially adjacent residues)
    # 2) j is among the top_k nearest neighbors of i
cond = torch.logical_or(topk_matrix > 0.0, sep < kmin)
b,i,j = torch.where(cond)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['d'] = (xyz[b,j,1,:] - xyz[b,i,1,:]).detach() # no gradient through basis function
G.edata['w'] = pair[b,i,j]
return G
def get_bonded_neigh(idx):
'''
Input:
- idx: residue indices of given sequence (B,L)
Output:
- neighbor: bonded neighbor information with sign (B, L, L, 1)
'''
neighbor = idx[:,None,:] - idx[:,:,None]
neighbor = neighbor.float()
sign = torch.sign(neighbor) # (B, L, L)
neighbor = torch.abs(neighbor)
neighbor[neighbor > 1] = 0.0
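    # keep only directly bonded pairs (|i-j| == 1); the sign encodes direction (+1: j follows i, -1: j precedes i)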
neighbor = sign * neighbor
return neighbor.unsqueeze(-1)
def rbf(D):
# Distance radial basis function
D_min, D_max, D_count = 0., 20., 36
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu[None,:]
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
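    # (..., D_count) Gaussian features: each distance is encoded against 36 centers evenly spaced on [0, 20]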
return RBF
class CoevolExtractor(nn.Module):
def __init__(self, n_feat_proj, n_feat_out, p_drop=0.1):
super(CoevolExtractor, self).__init__()
self.norm_2d = LayerNorm(n_feat_proj*n_feat_proj)
# project down to output dimension (pair feature dimension)
self.proj_2 = nn.Linear(n_feat_proj**2, n_feat_out)
def forward(self, x_down, x_down_w):
B, N, L = x_down.shape[:3]
pair = torch.einsum('abij,ablm->ailjm', x_down, x_down_w) # outer-product & average pool
pair = pair.reshape(B, L, L, -1)
pair = self.norm_2d(pair)
pair = self.proj_2(pair) # (B, L, L, n_feat_out) # project down to pair dimension
return pair
class MSA2Pair(nn.Module):
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32,
n_resblock=1, p_drop=0.1, n_att_head=8):
super(MSA2Pair, self).__init__()
# project down embedding dimension (n_feat --> n_feat_proj)
self.norm_1 = LayerNorm(n_feat)
self.proj_1 = nn.Linear(n_feat, n_feat_proj)
self.encoder = SequenceWeight(n_feat_proj, 1, dropout=p_drop)
self.coevol = CoevolExtractor(n_feat_proj, n_feat_out)
# ResNet to update pair features
self.norm_down = LayerNorm(n_feat_proj)
self.norm_orig = LayerNorm(n_feat_out)
self.norm_new = LayerNorm(n_feat_out)
self.update = ResidualNetwork(n_resblock, n_feat_out*2+n_feat_proj*4+n_att_head, n_feat_out, n_feat_out, p_drop=p_drop)
def forward(self, msa, pair_orig, att):
# Input: MSA embeddings (B, N, L, K), original pair embeddings (B, L, L, C)
# Output: updated pair info (B, L, L, C)
B, N, L, _ = msa.shape
# project down to reduce memory
msa = self.norm_1(msa)
x_down = self.proj_1(msa) # (B, N, L, n_feat_proj)
# get sequence weight
x_down = self.norm_down(x_down)
w_seq = self.encoder(x_down).reshape(B, L, 1, N).permute(0,3,1,2)
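        # per-sequence weights at each position, shape (B, N, L, 1)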
feat_1d = w_seq*x_down
pair = self.coevol(x_down, feat_1d)
# average pooling over N of given MSA info
feat_1d = feat_1d.sum(1)
# query sequence info
query = x_down[:,0] # (B,L,K)
feat_1d = torch.cat((feat_1d, query), dim=-1) # additional 1D features
# tile 1D features
left = feat_1d.unsqueeze(2).repeat(1, 1, L, 1)
right = feat_1d.unsqueeze(1).repeat(1, L, 1, 1)
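        # left/right: (B, L, L, 2*n_feat_proj); left[b,i,j] carries residue-i features, right[b,i,j] residue-j features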
# update original pair features through convolutions after concat
pair_orig = self.norm_orig(pair_orig)
pair = self.norm_new(pair)
pair = torch.cat((pair_orig, pair, left, right, att), -1)
pair = pair.permute(0,3,1,2).contiguous() # prep for convolution layer
pair = self.update(pair)
pair = pair.permute(0,2,3,1).contiguous() # (B, L, L, C)
return pair
class MSA2MSA(nn.Module):
def __init__(self, n_layer=1, n_att_head=8, n_feat=256, r_ff=4, p_drop=0.1,
performer_N_opts=None, performer_L_opts=None):
super(MSA2MSA, self).__init__()
# attention along L
enc_layer_1 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
use_tied=True)
#performer_opts=performer_L_opts)
self.encoder_1 = Encoder(enc_layer_1, n_layer)
# attention along N
enc_layer_2 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
performer_opts=performer_N_opts)
self.encoder_2 = Encoder(enc_layer_2, n_layer)
def forward(self, x):
# Input: MSA embeddings (B, N, L, K)
# Output: updated MSA embeddings (B, N, L, K)
B, N, L, _ = x.shape
# attention along L
x, att = self.encoder_1(x, return_att=True)
# attention along N
x = x.permute(0,2,1,3).contiguous()
x = self.encoder_2(x)
x = x.permute(0,2,1,3).contiguous()
return x, att
class Pair2MSA(nn.Module):
def __init__(self, n_layer=1, n_att_head=4, n_feat_in=128, n_feat_out=256, r_ff=4, p_drop=0.1):
super(Pair2MSA, self).__init__()
enc_layer = DirectEncoderLayer(heads=n_att_head, \
d_in=n_feat_in, d_out=n_feat_out,\
d_ff=n_feat_out*r_ff,\
p_drop=p_drop)
self.encoder = CrossEncoder(enc_layer, n_layer)
def forward(self, pair, msa):
out = self.encoder(pair, msa) # (B, N, L, K)
return out
class Pair2Pair(nn.Module):
def __init__(self, n_layer=1, n_att_head=8, n_feat=128, r_ff=4, p_drop=0.1,
performer_L_opts=None):
super(Pair2Pair, self).__init__()
enc_layer = AxialEncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
performer_opts=performer_L_opts)
self.encoder = Encoder(enc_layer, n_layer)
def forward(self, x):
return self.encoder(x)
class Str2Str(nn.Module):
def __init__(self, d_msa=64, d_pair=128,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
super(Str2Str, self).__init__()
# initial node & pair feature process
self.norm_msa = LayerNorm(d_msa)
self.norm_pair = LayerNorm(d_pair)
self.encoder_seq = SequenceWeight(d_msa, 1, dropout=p_drop)
self.embed_x = nn.Linear(d_msa+21, SE3_param['l0_in_features'])
self.embed_e = nn.Linear(d_pair, SE3_param['num_edge_features'])
self.norm_node = LayerNorm(SE3_param['l0_in_features'])
self.norm_edge = LayerNorm(SE3_param['num_edge_features'])
self.se3 = SE3Transformer(**SE3_param)
@torch.cuda.amp.autocast(enabled=False)
def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
# process msa & pair features
B, N, L = msa.shape[:3]
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
w_seq = self.encoder_seq(msa).reshape(B, L, 1, N).permute(0,3,1,2)
msa = w_seq*msa
msa = msa.sum(dim=1)
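        # weighted sum over the N sequences (weights from SequenceWeight) -> per-residue summary (B, L, d_msa)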
msa = torch.cat((msa, seq1hot), dim=-1)
msa = self.norm_node(self.embed_x(msa))
pair = self.norm_edge(self.embed_e(pair))
# define graph
G = make_graph(xyz, pair, idx, top_k=top_k)
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2) # l1 features = displacement vector to CA
l1_feats = l1_feats.reshape(B*L, -1, 3)
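        # SE(3)-Transformer inputs: type-0 (scalar) node features from the pooled MSA/query embedding,
        # type-1 (vector) features are the N/CA/C displacement vectors from CA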
# apply SE(3) Transformer & update coordinates
shift = self.se3(G, msa.reshape(B*L, -1, 1), l1_feats)
state = shift['0'].reshape(B, L, -1) # (B, L, C)
offset = shift['1'].reshape(B, L, -1, 3) # (B, L, 3, 3)
CA_new = xyz[:,:,1] + offset[:,:,1]
N_new = CA_new + offset[:,:,0]
C_new = CA_new + offset[:,:,2]
xyz_new = torch.stack([N_new, CA_new, C_new], dim=2)
return xyz_new, state
class Str2MSA(nn.Module):
def __init__(self, d_msa=64, d_state=32, inner_dim=32, r_ff=4,
distbin=[8.0, 12.0, 16.0, 20.0], p_drop=0.1):
super(Str2MSA, self).__init__()
self.distbin = distbin
n_att_head = len(distbin)
self.norm_state = LayerNorm(d_state)
self.norm1 = LayerNorm(d_msa)
self.attn = MaskedDirectMultiheadAttention(d_state, d_msa, n_att_head, d_k=inner_dim, dropout=p_drop)
self.dropout1 = nn.Dropout(p_drop,inplace=True)
self.norm2 = LayerNorm(d_msa)
self.ff = FeedForwardLayer(d_msa, d_msa*r_ff, p_drop=p_drop)
self.dropout2 = nn.Dropout(p_drop,inplace=True)
def forward(self, msa, xyz, state):
dist = torch.cdist(xyz[:,:,1], xyz[:,:,1]) # (B, L, L)
mask_s = list()
for distbin in self.distbin:
mask_s.append(1.0 - torch.sigmoid(dist-distbin))
mask_s = torch.stack(mask_s, dim=1) # (B, h, L, L)
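        # one soft mask per attention head: mask_s[:,k] is ~1 for pairs closer than distbin[k] and decays smoothly to 0 beyond it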
state = self.norm_state(state)
msa2 = self.norm1(msa)
msa2 = self.attn(state, state, msa2, mask_s)
msa = msa + self.dropout1(msa2)
msa2 = self.norm2(msa)
msa2 = self.ff(msa2)
msa = msa + self.dropout2(msa2)
return msa
class IterBlock(nn.Module):
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None):
super(IterBlock, self).__init__()
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
r_ff=r_ff, p_drop=p_drop,
performer_N_opts=performer_N_opts,
performer_L_opts=performer_L_opts)
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
performer_L_opts=performer_L_opts)
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
def forward(self, msa, pair):
# input:
        #   msa: initial MSA embeddings (B, N, L, d_msa)
        #   pair: initial residue pair embeddings (B, L, L, d_pair)
# 1. process MSA features
msa, att = self.msa2msa(msa)
# 2. update pair features using given MSA
pair = self.msa2pair(msa, pair, att)
# 3. process pair features
pair = self.pair2pair(pair)
# 4. update MSA features using updated pair features
msa = self.pair2msa(pair, msa)
return msa, pair
class IterBlock_w_Str(nn.Module):
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
super(IterBlock_w_Str, self).__init__()
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
r_ff=r_ff, p_drop=p_drop,
performer_N_opts=performer_N_opts,
performer_L_opts=performer_L_opts)
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
performer_L_opts=performer_L_opts)
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
self.str2msa = Str2MSA(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
r_ff=r_ff, p_drop=p_drop)
def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
# input:
        #   msa: initial MSA embeddings (B, N, L, d_msa)
        #   pair: initial residue pair embeddings (B, L, L, d_pair)
# 1. process MSA features
msa, att = self.msa2msa(msa)
# 2. update pair features using given MSA
pair = self.msa2pair(msa, pair, att)
# 3. process pair features
pair = self.pair2pair(pair)
# 4. update MSA features using updated pair features
msa = self.pair2msa(pair, msa)
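        # 5. update backbone coordinates with the SE(3)-Transformer, then feed structure information back into the MSA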
xyz, state = self.str2str(msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=top_k)
msa = self.str2msa(msa, xyz, state)
return msa, pair, xyz
class FinalBlock(nn.Module):
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
super(FinalBlock, self).__init__()
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
r_ff=r_ff, p_drop=p_drop,
performer_N_opts=performer_N_opts,
performer_L_opts=performer_L_opts)
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
performer_L_opts=performer_L_opts)
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
self.norm_state = LayerNorm(SE3_param['l0_out_features'])
self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)
def forward(self, msa, pair, xyz, seq1hot, idx):
# input:
        #   msa: initial MSA embeddings (B, N, L, d_msa)
        #   pair: initial residue pair embeddings (B, L, L, d_pair)
# 1. process MSA features
msa, att = self.msa2msa(msa)
# 2. update pair features using given MSA
pair = self.msa2pair(msa, pair, att)
# 3. process pair features
pair = self.pair2pair(pair)
msa = self.pair2msa(pair, msa)
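        # 5. final structure refinement on a smaller kNN graph (top_k=32) and per-residue lDDT prediction from the SE(3) state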
xyz, state = self.str2str(msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=32)
lddt = self.pred_lddt(self.norm_state(state))
return msa, pair, xyz, lddt.squeeze(-1)
class IterativeFeatureExtractor(nn.Module):
def __init__(self, n_module=4, n_module_str=4, n_layer=4, d_msa=256, d_pair=128, d_hidden=64,
n_head_msa=8, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1,
performer_L_opts=None, performer_N_opts=None,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
super(IterativeFeatureExtractor, self).__init__()
self.n_module = n_module
self.n_module_str = n_module_str
#
self.initial = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
performer_L_opts=performer_L_opts)
if self.n_module > 0:
self.iter_block_1 = _get_clones(IterBlock(n_layer=n_layer,
d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
r_ff=r_ff,
n_resblock=n_resblock,
p_drop=p_drop,
performer_N_opts=performer_N_opts,
performer_L_opts=performer_L_opts
), n_module)
self.init_str = InitStr_Network(node_dim_in=d_msa, node_dim_hidden=d_hidden,
edge_dim_in=d_pair, edge_dim_hidden=d_hidden,
nheads=4, nblocks=3, dropout=p_drop)
if self.n_module_str > 0:
self.iter_block_2 = _get_clones(IterBlock_w_Str(n_layer=n_layer,
d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
r_ff=r_ff,
n_resblock=n_resblock,
p_drop=p_drop,
performer_N_opts=performer_N_opts,
performer_L_opts=performer_L_opts,
SE3_param=SE3_param
), n_module_str)
self.final = FinalBlock(n_layer=n_layer, d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa, n_head_pair=n_head_pair, r_ff=r_ff,
n_resblock=n_resblock, p_drop=p_drop,
performer_L_opts=performer_L_opts, performer_N_opts=performer_N_opts,
SE3_param=SE3_param)
def forward(self, msa, pair, seq1hot, idx):
# input:
        #   msa: initial MSA embeddings (B, N, L, d_msa)
        #   pair: initial residue pair embeddings (B, L, L, d_pair)
pair_s = list()
pair = self.initial(pair)
if self.n_module > 0:
for i_m in range(self.n_module):
# extract features from MSA & update original pair features
msa, pair = self.iter_block_1[i_m](msa, pair)
xyz = self.init_str(seq1hot, idx, msa, pair)
top_ks = [128, 128, 64, 64]
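        # per-block neighbor counts for graph construction; assumes n_module_str <= 4 (indexing fails otherwise)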
if self.n_module_str > 0:
for i_m in range(self.n_module_str):
msa, pair, xyz = self.iter_block_2[i_m](msa, pair, xyz, seq1hot, idx, top_k=top_ks[i_m])
msa, pair, xyz, lddt = self.final(msa, pair, xyz, seq1hot, idx)
return msa[:,0], pair, xyz, lddt
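
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# The sizes below are made up; they only show the shapes IterativeFeatureExtractor
# expects, based on the docstrings above. Running it requires the same
# dependencies as the module itself (dgl, SE3Transformer, InitStr_Network, ...).
#
# if __name__ == "__main__":
#     B, N, L = 1, 8, 64                            # batch, #sequences, length
#     model = IterativeFeatureExtractor(n_module=1, n_module_str=1,
#                                       d_msa=64, d_pair=128)
#     msa     = torch.randn(B, N, L, 64)            # MSA embeddings
#     pair    = torch.randn(B, L, L, 128)           # pair embeddings
#     seq1hot = torch.randn(B, L, 21)               # query sequence encoding (nominally one-hot over 21 residue types)
#     idx     = torch.arange(L)[None].expand(B, -1) # residue indices
#     msa_out, pair_out, xyz, lddt = model(msa, pair, seq1hot, idx)
# ---------------------------------------------------------------------------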