176 lines
7.6 KiB
Python
176 lines
7.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.checkpoint as checkpoint
|
|
from Transformer import LayerNorm
|
|
from InitStrGenerator import make_graph
|
|
from InitStrGenerator import get_seqsep, UniMPBlock
|
|
from Attention_module_w_str import make_graph as make_graph_topk
|
|
from Attention_module_w_str import get_bonded_neigh, rbf
|
|
from SE3_network import SE3Transformer
|
|
from Transformer import _get_clones, create_custom_forward
|
|
# Re-generate initial coordinates based on 1) final pair features 2) predicted distogram
|
|
# Then, refine them through multiple SE3 transformer blocks
|
|
|
|
class Regen_Network(nn.Module):
    """Graph-transformer head that regenerates coordinates from node/edge features.

    Node and edge features are layer-normalised, augmented with sequence-derived
    features, embedded, and passed through a stack of ``UniMPBlock`` layers.
    Two linear heads then read out, for every node, 9 coordinate values
    (returned as a 3x3 block) and a ``state_dim``-sized state vector.

    Args:
        node_dim_in: channel size of the incoming node features.
        node_dim_hidden: channel size of the embedded node features.
        edge_dim_in: channel size of the incoming edge features.
        edge_dim_hidden: channel size of the embedded edge features.
        state_dim: channel size of the per-node state output.
        nheads: number of heads handed to every ``UniMPBlock``.
        nblocks: number of ``UniMPBlock`` layers stacked in sequence.
        dropout: dropout value handed to every ``UniMPBlock``.
    """

    def __init__(self, node_dim_in=64, node_dim_hidden=64, edge_dim_in=128,
                 edge_dim_hidden=64, state_dim=8, nheads=4, nblocks=3, dropout=0.0):
        super().__init__()

        # Input normalisation for the raw node / edge features.
        self.norm_node = LayerNorm(node_dim_in)
        self.norm_edge = LayerNorm(edge_dim_in)

        # Embeddings. The node input is widened by 21 channels (the width of
        # ``seq1hot`` -- presumably a one-hot residue encoding, TODO confirm)
        # and the edge input by 2 channels (sequence separation + bonded
        # neighbour flag, assumed to be one channel each).
        node_proj = nn.Linear(node_dim_in + 21, node_dim_hidden)
        self.embed_x = nn.Sequential(node_proj, LayerNorm(node_dim_hidden))
        edge_proj = nn.Linear(edge_dim_in + 2, edge_dim_hidden)
        self.embed_e = nn.Sequential(edge_proj, LayerNorm(edge_dim_hidden))

        # Graph transformer: ``nblocks`` identical-shape UniMP blocks in series.
        layers = []
        for _ in range(nblocks):
            layers.append(UniMPBlock(node_dim_hidden, edge_dim_hidden, nheads, dropout))
        self.transformer = nn.Sequential(*layers)

        # Output heads: 9 coordinate values and a state vector per node.
        self.get_xyz = nn.Linear(node_dim_hidden, 9)
        self.norm_state = LayerNorm(node_dim_hidden)
        self.get_state = nn.Linear(node_dim_hidden, state_dim)

    def forward(self, seq1hot, idx, node, edge):
        """Predict coordinates and state features.

        Args:
            seq1hot: per-node sequence features concatenated onto ``node``.
            idx: index tensor passed to ``get_seqsep``, ``get_bonded_neigh``
                and ``make_graph`` (presumably residue indices -- verify).
            node: node features; the first two dims are taken as (B, L).
            edge: edge features with ``edge_dim_in`` channels.

        Returns:
            Tuple ``(xyz, state)`` reshaped to (B, L, 3, 3) and (B, L, -1).
        """
        n_batch, n_res = node.shape[:2]

        # Node stream: normalise, append the sequence encoding, embed.
        node_in = torch.cat((self.norm_node(node), seq1hot), dim=-1)
        node_emb = self.embed_x(node_in)

        # Edge stream: normalise, append sequence-derived edge features, embed.
        edge_in = torch.cat(
            (self.norm_edge(edge), get_seqsep(idx), get_bonded_neigh(idx)),
            dim=-1,
        )
        edge_emb = self.embed_e(edge_in)

        # Run the graph transformer and read the updated node features.
        graph_out = self.transformer(make_graph(node_emb, idx, edge_emb))
        node_out = graph_out.x

        coords = self.get_xyz(node_out).reshape(n_batch, n_res, 3, 3)
        state = self.get_state(self.norm_state(node_out)).reshape(n_batch, n_res, -1)
        return coords, state
|
|
|
|
class Refine_Network(nn.Module):
    """One SE(3)-transformer refinement step over backbone coordinates.

    Node features (MSA-derived features, sequence encoding, previous state) and
    pair features (pair embedding, RBF-encoded distances, bonded-neighbour
    flag) are embedded, a top-k graph is built on the current coordinates, and
    an ``SE3Transformer`` predicts a new state plus coordinate offsets.

    Args:
        d_node: channel size of the incoming node (``msa``) features.
        d_pair: channel size of the incoming pair features.
        d_state: channel size of the incoming state features.
        SE3_param: keyword arguments forwarded to ``SE3Transformer``; the keys
            ``l0_in_features`` and ``num_edge_features`` also size the
            embedding layers. The default dict is shared between calls and is
            only read here.
        p_drop: accepted for interface compatibility; not used in this class.
    """

    def __init__(self, d_node=64, d_pair=128, d_state=16,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
        super().__init__()
        n_l0_in = SE3_param['l0_in_features']
        n_edge = SE3_param['num_edge_features']

        # Normalisation of the raw inputs.
        self.norm_msa = LayerNorm(d_node)
        self.norm_pair = LayerNorm(d_pair)
        self.norm_state = LayerNorm(d_state)

        # Embeddings. Node input = node features + 21 sequence channels + state.
        # The second edge embedding takes 36 + 1 extra channels (assumed to be
        # the RBF distance features and the bonded-neighbour flag -- confirm
        # against ``rbf`` / ``get_bonded_neigh``).
        self.embed_x = nn.Linear(d_node + 21 + d_state, n_l0_in)
        self.embed_e1 = nn.Linear(d_pair, n_edge)
        self.embed_e2 = nn.Linear(n_edge + 36 + 1, n_edge)

        # Normalisation applied after each embedding.
        self.norm_node = LayerNorm(n_l0_in)
        self.norm_edge1 = LayerNorm(n_edge)
        self.norm_edge2 = LayerNorm(n_edge)

        self.se3 = SE3Transformer(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, seq1hot, idx, top_k=64):
        """Refine coordinates once (autocast is disabled for this call).

        Args:
            msa: node features; the first two dims are taken as (B, L).
            pair: pair features with ``d_pair`` channels.
            xyz: current coordinates; atom index 1 on dim 2 is treated as CA.
            state: per-node state features with ``d_state`` channels.
            seq1hot: per-node sequence features concatenated onto the nodes.
            idx: index tensor passed to ``get_bonded_neigh`` and the graph builder.
            top_k: neighbour count passed to ``make_graph_topk``.

        Returns:
            Tuple ``(xyz_new, state)``: updated N/CA/C coordinates stacked on
            dim 2, and the degree-0 SE(3) output reshaped to (B, L, -1).
        """
        n_batch, n_res = msa.shape[:2]

        # --- node features -------------------------------------------------
        node_in = torch.cat(
            (self.norm_msa(msa), seq1hot, self.norm_state(state)), dim=-1)
        node_emb = self.norm_node(self.embed_x(node_in))

        # --- pair features -------------------------------------------------
        pair_emb = self.norm_edge1(self.embed_e1(self.norm_pair(pair)))
        bonded = get_bonded_neigh(idx)
        ca = xyz[:, :, 1, :]
        dist_feat = rbf(torch.cdist(ca, ca))  # RBF encoding of CA-CA distances
        pair_emb = torch.cat((pair_emb, dist_feat, bonded), dim=-1)
        pair_emb = self.norm_edge2(self.embed_e2(pair_emb))

        # --- graph + type-1 inputs ------------------------------------------
        graph = make_graph_topk(xyz, pair_emb, idx, top_k=top_k)
        # l1 features = displacement vector of every atom to CA
        l1_feats = (xyz - ca.unsqueeze(2)).reshape(n_batch * n_res, -1, 3)

        # --- SE(3) transformer ----------------------------------------------
        l0_feats = node_emb.reshape(n_batch * n_res, -1, 1)
        shift = self.se3(graph, l0_feats, l1_feats)

        new_state = shift['0'].reshape(n_batch, n_res, -1)     # (B, L, C)
        offset = shift['1'].reshape(n_batch, n_res, -1, 3)     # (B, L, 3, 3)

        # CA is moved by its own offset; N and C are then placed relative to
        # the *new* CA using their offsets.
        ca_new = xyz[:, :, 1] + offset[:, :, 1]
        n_new = ca_new + offset[:, :, 0]
        c_new = ca_new + offset[:, :, 2]
        xyz_new = torch.stack([n_new, ca_new, c_new], dim=2)

        return xyz_new, new_state
|
|
|
|
class Refine_module(nn.Module):
    """Regenerate coordinates, then refine them iteratively with SE(3) blocks.

    ``Regen_Network`` produces initial coordinates and state from node/edge
    features. The structure and its z-mirrored copy are then refined by
    ``n_module`` ``Refine_Network`` blocks, applied repeatedly until the
    predicted lDDT stops improving. The candidate with the highest mean
    predicted lDDT is returned.

    Args:
        n_module: number of ``Refine_Network`` clones applied per iteration.
        d_node: channel size of the incoming node features.
        d_node_hidden: hidden node channel size inside ``Regen_Network``.
        d_pair: channel size of the incoming edge features.
        d_pair_hidden: hidden edge size; edges are projected to twice this.
        SE3_param: keyword arguments for the SE(3) transformer;
            ``l0_out_features`` is also used as the state size. The default
            dict is shared between calls and is only read here.
        p_drop: dropout value forwarded to the sub-networks.
    """

    def __init__(self, n_module, d_node=64, d_node_hidden=64, d_pair=128, d_pair_hidden=64,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
        super(Refine_module, self).__init__()
        self.n_module = n_module
        self.proj_edge = nn.Linear(d_pair, d_pair_hidden*2)

        self.regen_net = Regen_Network(node_dim_in=d_node, node_dim_hidden=d_node_hidden,
                                       edge_dim_in=d_pair_hidden*2, edge_dim_hidden=d_pair_hidden,
                                       state_dim=SE3_param['l0_out_features'],
                                       nheads=4, nblocks=3, dropout=p_drop)
        self.refine_net = _get_clones(Refine_Network(d_node=d_node, d_pair=d_pair_hidden*2,
                                                     d_state=SE3_param['l0_out_features'],
                                                     SE3_param=SE3_param, p_drop=p_drop), self.n_module)
        self.norm_state = LayerNorm(SE3_param['l0_out_features'])
        self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)

    def forward(self, node, edge, seq1hot, idx, use_transf_checkpoint=False, eps=1e-4):
        """Run regeneration followed by iterative refinement.

        Args:
            node: node features passed to the regeneration/refinement networks.
            edge: edge features; projected with ``proj_edge`` before use.
            seq1hot: per-node sequence features.
            idx: index tensor forwarded to the sub-networks.
            use_transf_checkpoint: wrap each refinement block in
                ``torch.utils.checkpoint`` when True.
            eps: minimum lDDT gain that counts as an improvement for the
                early-stopping counters.

        Returns:
            Tuple ``(xyz, lddt)`` for the best candidate only, each with a
            leading dim of 1. ``lddt`` always has shape (1, L), including
            the case where no iteration improved on the initial zero score.
        """
        edge = self.proj_edge(edge)

        xyz, state = self.regen_net(seq1hot, idx, node, edge)

        # DOUBLE IT w/ Mirror images: refine both the regenerated structure and
        # its reflection through the xy-plane, and let predicted lDDT choose.
        mirror = torch.tensor([1, 1, -1], dtype=xyz.dtype, device=xyz.device)
        xyz = torch.cat([xyz, xyz * mirror])
        state = torch.cat([state, state])
        node = torch.cat([node, node])
        edge = torch.cat([edge, edge])
        idx = torch.cat([idx, idx])
        seq1hot = torch.cat([seq1hot, seq1hot])

        best_xyz = xyz
        # Same (B, L) layout as the per-iteration lDDT below, so the returned
        # tensor has one shape whether or not any iteration improves on it.
        best_lddt = torch.zeros(xyz.shape[:2], device=xyz.device)
        prev_lddt = 0.0
        no_impr = 0       # consecutive iterations without gain vs. previous one
        no_impr_best = 0  # consecutive iterations without gain vs. best so far
        for i_iter in range(200):
            for i_m in range(self.n_module):
                # Coordinates are detached between blocks, so gradients do not
                # flow through the coordinate path of earlier blocks.
                if use_transf_checkpoint:
                    xyz, state = checkpoint.checkpoint(create_custom_forward(self.refine_net[i_m], top_k=64), node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx)
                else:
                    xyz, state = self.refine_net[i_m](node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx, top_k=64)

            lddt = self.pred_lddt(self.norm_state(state))
            lddt = torch.clamp(lddt, 0.0, 1.0)[...,0]

            # Scores are used only for bookkeeping: detach them so printing and
            # comparing also work when autograd is recording.
            per_model = lddt.detach().mean(-1)
            print (f"SE(3) iteration {i_iter} {per_model.cpu().numpy()}")
            cur_score = per_model.max()
            best_score = best_lddt.detach().mean(-1).max()

            if cur_score <= prev_lddt + eps:
                no_impr += 1
            else:
                no_impr = 0
            if cur_score <= best_score + eps:
                no_impr_best += 1
            else:
                no_impr_best = 0
            # Early stopping is checked before the best candidate is updated.
            if no_impr > 10 or no_impr_best > 20:
                break
            if cur_score > best_score:
                best_lddt = lddt
                best_xyz = xyz
            prev_lddt = cur_score

        pick = best_lddt.mean(-1).argmax()
        return best_xyz[pick][None], best_lddt[pick][None]
|