# DeepLearningExamples/DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Refine_module.py
# (snapshot metadata: 2021-10-15 15:46:41 +02:00, 176 lines, 7.6 KiB, Python)
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from Transformer import LayerNorm
from InitStrGenerator import make_graph
from InitStrGenerator import get_seqsep, UniMPBlock
from Attention_module_w_str import make_graph as make_graph_topk
from Attention_module_w_str import get_bonded_neigh, rbf
from SE3_network import SE3Transformer
from Transformer import _get_clones, create_custom_forward
# Re-generate initial coordinates based on 1) final pair features 2) predicted distogram
# Then, refine it through multiple SE3 transformer block
class Regen_Network(nn.Module):
    """Regenerate initial 3-atom backbone coordinates from pair features.

    Embeds node (MSA-derived) and edge (pair) features, runs them through a
    stack of UniMP graph-transformer blocks, and reads out per-residue
    coordinates (N, CA, C) plus a hidden state vector.
    """
    def __init__(self,
                 node_dim_in=64,
                 node_dim_hidden=64,
                 edge_dim_in=128,
                 edge_dim_hidden=64,
                 state_dim=8,
                 nheads=4,
                 nblocks=3,
                 dropout=0.0):
        super(Regen_Network, self).__init__()
        # Input normalization + embedding; node input is concatenated with a
        # 21-dim one-hot sequence, edges with seqsep + bonded-neighbor flags (+2).
        self.norm_node = LayerNorm(node_dim_in)
        self.norm_edge = LayerNorm(edge_dim_in)
        self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden), LayerNorm(node_dim_hidden))
        self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+2, edge_dim_hidden), LayerNorm(edge_dim_hidden))
        # Graph-transformer trunk: nblocks UniMP layers applied in sequence.
        self.transformer = nn.Sequential(
            *(UniMPBlock(node_dim_hidden, edge_dim_hidden, nheads, dropout) for _ in range(nblocks))
        )
        # Output heads: 9 values per residue -> (3 atoms x 3 coords), plus state.
        self.get_xyz = nn.Linear(node_dim_hidden, 9)
        self.norm_state = LayerNorm(node_dim_hidden)
        self.get_state = nn.Linear(node_dim_hidden, state_dim)

    def forward(self, seq1hot, idx, node, edge):
        """Return (xyz, state): coordinates (B, L, 3, 3) and state (B, L, state_dim)."""
        n_batch, n_res = node.shape[:2]
        # Node track: normalize, append one-hot sequence, embed.
        x = self.norm_node(node)
        x = torch.cat((x, seq1hot), dim=-1)
        x = self.embed_x(x)
        # Edge track: normalize, append sequence-separation and bonded-neighbor
        # features derived from residue indices, embed.
        e = self.norm_edge(edge)
        e = torch.cat((e, get_seqsep(idx), get_bonded_neigh(idx)), dim=-1)
        e = self.embed_e(e)
        # Build the residue graph and run the transformer trunk.
        graph_out = self.transformer(make_graph(x, idx, e))
        coords = self.get_xyz(graph_out.x)
        hidden = self.get_state(self.norm_state(graph_out.x))
        return coords.reshape(n_batch, n_res, 3, 3), hidden.reshape(n_batch, n_res, -1)
class Refine_Network(nn.Module):
    """One SE(3)-transformer refinement step for backbone coordinates.

    Embeds node/pair/state features, builds a top-k neighbor graph from the
    current coordinates, and applies an SE(3)-equivariant transformer whose
    type-1 outputs are per-atom displacement vectors used to update N/CA/C.
    """
    def __init__(self, d_node=64, d_pair=128, d_state=16,
                 SE3_param=None, p_drop=0.0):
        super(Refine_Network, self).__init__()
        # Fix for mutable-default-argument hazard: the original used a shared
        # dict literal as the default; behavior is identical for all callers.
        if SE3_param is None:
            SE3_param = {'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32}
        self.norm_msa = LayerNorm(d_node)
        self.norm_pair = LayerNorm(d_pair)
        self.norm_state = LayerNorm(d_state)
        # Node embedding input: node features + 21-dim one-hot sequence + state.
        self.embed_x = nn.Linear(d_node+21+d_state, SE3_param['l0_in_features'])
        self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
        # Second edge embedding input: +36 RBF distance bins +1 bonded-neighbor flag.
        self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
        self.norm_node = LayerNorm(SE3_param['l0_in_features'])
        self.norm_edge1 = LayerNorm(SE3_param['num_edge_features'])
        self.norm_edge2 = LayerNorm(SE3_param['num_edge_features'])
        self.se3 = SE3Transformer(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, seq1hot, idx, top_k=64):
        """Refine coordinates once.

        Args:
            msa: node features (B, L, d_node).
            pair: pair features (B, L, L, d_pair).
            xyz: backbone coordinates (B, L, 3, 3); atom axis is indexed 0=N,
                1=CA, 2=C (inferred from the N/CA/C update below — confirm).
            state: per-residue hidden state (B, L, d_state).
            seq1hot: one-hot sequence (B, L, 21).
            idx: residue indices (B, L).
            top_k: number of nearest neighbors for the coordinate graph.

        Returns:
            (xyz_new, state): updated coordinates (B, L, 3, 3) and the SE(3)
            transformer's type-0 output as the new state (B, L, C).
        """
        # Process node & pair features.
        B, L = msa.shape[:2]
        node = self.norm_msa(msa)
        pair = self.norm_pair(pair)
        state = self.norm_state(state)
        node = torch.cat((node, seq1hot, state), dim=-1)
        node = self.norm_node(self.embed_x(node))
        pair = self.norm_edge1(self.embed_e1(pair))
        # Augment edges with CA-CA distance RBF features and bonded-neighbor flags.
        neighbor = get_bonded_neigh(idx)
        rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
        pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        pair = self.norm_edge2(self.embed_e2(pair))
        # Define top-k neighbor graph over current coordinates.
        G = make_graph_topk(xyz, pair, idx, top_k=top_k)
        # l1 (vector) features: per-atom displacement from CA.
        l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)
        # Apply SE(3) transformer; '0' = scalar outputs, '1' = vector outputs.
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats)
        state = shift['0'].reshape(B, L, -1)        # (B, L, C)
        offset = shift['1'].reshape(B, L, -1, 3)    # (B, L, 3, 3)
        # CA moves by its own offset; N and C are placed relative to the new CA.
        CA_new = xyz[:,:,1] + offset[:,:,1]
        N_new = CA_new + offset[:,:,0]
        C_new = CA_new + offset[:,:,2]
        xyz_new = torch.stack([N_new, CA_new, C_new], dim=2)
        return xyz_new, state
class Refine_module(nn.Module):
    """Full refinement pipeline: regenerate coordinates, then iteratively refine.

    Generates an initial structure with Regen_Network, doubles the batch with
    its mirror image (z-flip) to cover both chiralities, then repeatedly applies
    a stack of Refine_Network steps with pLDDT-based early stopping, returning
    the single best-scoring model.
    """
    def __init__(self, n_module, d_node=64, d_node_hidden=64, d_pair=128, d_pair_hidden=64,
                 SE3_param=None, p_drop=0.0):
        super(Refine_module, self).__init__()
        # Fix for mutable-default-argument hazard: the original used a shared
        # dict literal as the default; behavior is identical for all callers.
        if SE3_param is None:
            SE3_param = {'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32}
        self.n_module = n_module
        self.proj_edge = nn.Linear(d_pair, d_pair_hidden*2)
        self.regen_net = Regen_Network(node_dim_in=d_node, node_dim_hidden=d_node_hidden,
                                       edge_dim_in=d_pair_hidden*2, edge_dim_hidden=d_pair_hidden,
                                       state_dim=SE3_param['l0_out_features'],
                                       nheads=4, nblocks=3, dropout=p_drop)
        # n_module independent (cloned) refinement steps applied per iteration.
        self.refine_net = _get_clones(Refine_Network(d_node=d_node, d_pair=d_pair_hidden*2,
                                                     d_state=SE3_param['l0_out_features'],
                                                     SE3_param=SE3_param, p_drop=p_drop), self.n_module)
        # pLDDT head on the refinement state.
        self.norm_state = LayerNorm(SE3_param['l0_out_features'])
        self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)

    def forward(self, node, edge, seq1hot, idx, use_transf_checkpoint=False, eps=1e-4):
        """Refine and return the best model.

        Args:
            node: node features (B, L, d_node).
            edge: pair features (B, L, L, d_pair); projected to 2*d_pair_hidden.
            seq1hot: one-hot sequence (B, L, 21).
            idx: residue indices (B, L).
            use_transf_checkpoint: use gradient checkpointing for refine steps.
            eps: minimum pLDDT gain counted as an improvement.

        Returns:
            (best_xyz, best_lddt): coordinates (1, L, 3, 3) and per-residue
            pLDDT (1, L) of the highest-scoring candidate.
        """
        edge = self.proj_edge(edge)
        xyz, state = self.regen_net(seq1hot, idx, node, edge)
        # Double the batch with mirror images (flip z) so both chiralities
        # compete; all per-residue inputs are duplicated to match.
        xyz = torch.cat([xyz, xyz*torch.tensor([1,1,-1], dtype=xyz.dtype, device=xyz.device)])
        state = torch.cat([state, state])
        node = torch.cat([node, node])
        edge = torch.cat([edge, edge])
        idx = torch.cat([idx, idx])
        seq1hot = torch.cat([seq1hot, seq1hot])
        best_xyz = xyz
        best_lddt = torch.zeros((xyz.shape[0], xyz.shape[1], 1), device=xyz.device)
        prev_lddt = 0.0
        no_impr = 0        # iterations without improvement over the previous iteration
        no_impr_best = 0   # iterations without improvement over the best so far
        for i_iter in range(200):
            # One iteration = the full stack of refinement steps.
            # Coordinates are detached so gradients flow only through the last stack.
            for i_m in range(self.n_module):
                if use_transf_checkpoint:
                    xyz, state = checkpoint.checkpoint(create_custom_forward(self.refine_net[i_m], top_k=64), node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx)
                else:
                    xyz, state = self.refine_net[i_m](node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx, top_k=64)
            # Predicted per-residue lDDT, clamped to [0, 1]; shape (B, L).
            lddt = self.pred_lddt(self.norm_state(state))
            lddt = torch.clamp(lddt, 0.0, 1.0)[...,0]
            print (f"SE(3) iteration {i_iter} {lddt.mean(-1).cpu().numpy()}")
            # Early-stopping bookkeeping on the best candidate's mean pLDDT.
            if lddt.mean(-1).max() <= prev_lddt+eps:
                no_impr += 1
            else:
                no_impr = 0
            if lddt.mean(-1).max() <= best_lddt.mean(-1).max()+eps:
                no_impr_best += 1
            else:
                no_impr_best = 0
            if no_impr > 10 or no_impr_best > 20:
                break
            if lddt.mean(-1).max() > best_lddt.mean(-1).max():
                best_lddt = lddt
                best_xyz = xyz
            prev_lddt = lddt.mean(-1).max()
        # Return only the candidate (original or mirror) with the highest mean pLDDT.
        pick = best_lddt.mean(-1).argmax()
        return best_xyz[pick][None], best_lddt[pick][None]