DeepLearningExamples/DGLPyTorch/DrugDiscovery/RoseTTAFold/network/predict_complex.py
2021-10-15 15:46:41 +02:00

318 lines
13 KiB
Python

import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
import util
from collections import namedtuple
from ffindex import *
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
from trFold import TRFold
# Parent directory of the directory that contains this script; used as the
# base for the default weights location.
script_dir = "/".join(
    os.path.dirname(os.path.realpath(__file__)).split("/")[:-1]
)

# Number of bins in each of the four 2D outputs, in the order the predictor
# later saves them: dist, omega, theta, phi.
NBIN = [37, 37, 37, 19]

# Hyper-parameters of the SE(3) structure module (passed to the network as
# MODEL_PARAM['SE3_param']).
SE3_param = dict(
    num_layers=2,
    num_channels=16,
    num_degrees=2,
    l0_in_features=32,
    l0_out_features=8,
    l1_in_features=3,
    l1_out_features=3,
    num_edge_features=32,
    div=2,
    n_heads=4,
)

# Hyper-parameters of the refinement module (passed to the network as
# MODEL_PARAM['REF_param']).
REF_param = dict(
    num_layers=3,
    num_channels=32,
    num_degrees=3,
    l0_in_features=32,
    l0_out_features=8,
    l1_in_features=3,
    l1_out_features=3,
    num_edge_features=32,
    div=4,
    n_heads=4,
)

# Keyword arguments used to construct RoseTTAFoldModule_e2e.
MODEL_PARAM = {
    "n_module": 8,
    "n_module_str": 4,
    "n_module_ref": 4,
    "n_layer": 1,
    "d_msa": 384,
    "d_pair": 288,
    "d_templ": 64,
    "n_head_msa": 12,
    "n_head_pair": 8,
    "n_head_templ": 4,
    "d_hidden": 64,
    "r_ff": 4,
    "n_resblock": 1,
    "p_drop": 0.1,
    "use_templ": True,
    "performer_N_opts": {"nb_features": 64},
    "performer_L_opts": {"nb_features": 64},
    # The two sub-module dicts are shared by reference with the module-level
    # names above.
    "SE3_param": SE3_param,
    "REF_param": REF_param,
}

# params for the folding protocol
fold_params = {
    # Smoothing kernels of width 7 and 9; each sums to 1 after the division.
    "SG7": np.array([[[-2, 3, 6, 7, 6, 3, -2]]]) / 21,
    "SG9": np.array([[[-21, 14, 39, 54, 59, 54, 39, 14, -21]]]) / 231,
    "DCUT": 19.5,
    "ALPHA": 1.57,
    # TODO: add Cb to the motif
    "NCAC": np.array(
        [
            [-0.676, -1.294, 0.0],
            [0.0, 0.0, 0.0],
            [1.5, -0.174, 0.0],
        ],
        dtype=np.float32,
    ),
    "CLASH": 2.0,
    "PCUT": 0.5,
    "DSTEP": 0.5,
    "ASTEP": np.deg2rad(10.0),
    "XYZRAD": 7.5,
    "WANG": 0.1,
    "WCST": 0.1,
}
# The kernel actually used is the 9-wide one (same array object as "SG9").
fold_params.update(SG=fold_params["SG9"])
class Predictor():
    """End-to-end RoseTTAFold predictor for multi-chain complexes.

    Builds ``RoseTTAFoldModule_e2e`` from ``MODEL_PARAM``, loads its weights
    from ``<model_dir>/RoseTTAFold_e2e.pt`` and exposes :meth:`predict`, which
    writes ``<out_prefix>.npz`` (2D probability maps) and ``<out_prefix>.pdb``.
    """

    # Chain letters assigned to the sub-units, in the order given by ``Ls``.
    _CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    # Fixed-column PDB ATOM record: serial in cols 7-11, atom name 13-16,
    # residue name 18-20, chain 22, residue number 23-26, x/y/z starting at
    # col 31, then occupancy and B-factor.  The four blanks after the residue
    # number are required for the coordinates to land in the right columns.
    _ATOM_FMT = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"

    # Four-character backbone atom names, so that '%4s' leaves the element
    # symbol in the column the PDB format expects.
    _BACKBONE_ATOMS = (" N  ", " CA ", " C  ")

    def __init__(self, model_dir=None, use_cpu=False):
        """Create the network and load its weights.

        Args:
            model_dir: directory holding the checkpoint; defaults to
                ``<script_dir>/weights`` when ``None``.
            use_cpu: force CPU execution even when CUDA is available.

        Exits the process with status 1 if the checkpoint file is missing.
        """
        if model_dir is None:
            self.model_dir = "%s/weights"%(script_dir)
        else:
            self.model_dir = model_dir
        #
        # define model name
        self.model_name = "RoseTTAFold"
        if torch.cuda.is_available() and (not use_cpu):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Softmax over dim 1 turns the network logits into probabilities.
        self.active_fn = nn.Softmax(dim=1)
        # define model & load model
        self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)
        could_load = self.load_model(self.model_name)
        if not could_load:
            print ("ERROR: failed to load model")
            # Non-zero status so that calling scripts can detect the failure.
            sys.exit(1)

    def load_model(self, model_name, suffix='e2e'):
        """Load ``<model_dir>/<model_name>_<suffix>.pt`` into ``self.model``.

        Returns:
            ``False`` if the checkpoint file does not exist, ``True`` once the
            state dict has been loaded (strictly, so key mismatches raise).
        """
        chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
        if not os.path.exists(chk_fn):
            return False
        checkpoint = torch.load(chk_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        return True

    def predict(self, a3m_fn, out_prefix, Ls, templ_npz=None, window=1000, shift=100):
        """Predict a complex structure from a paired MSA.

        Args:
            a3m_fn: path of the input alignment, read with ``parse_a3m``.
            out_prefix: prefix of the ``.npz`` and ``.pdb`` output files.
            Ls: lengths of the individual chains, in sequence order.
            templ_npz: optional ``.npz`` file with ``xyz_t``, ``t1d`` and
                ``t0d`` template arrays; blank templates are used if ``None``.
            window: crop size used when the sequence is longer than
                ``2 * window`` residues.
            shift: step between consecutive crop start positions.
        """
        msa = parse_a3m(a3m_fn)
        _, L = msa.shape
        #
        if templ_npz is not None:
            # np.load keeps the archive open; read the arrays and close it.
            with np.load(templ_npz) as templ:
                xyz_t = torch.from_numpy(templ["xyz_t"])
                t1d = torch.from_numpy(templ["t1d"])
                t0d = torch.from_numpy(templ["t0d"])
        else:
            # No template: NaN coordinates and all-zero template features.
            xyz_t = torch.full((1, L, 3, 3), np.nan).float()
            t1d = torch.zeros((1, L, 3)).float()
            t0d = torch.zeros((1,3)).float()

        self.model.eval()
        with torch.no_grad():
            #
            msa = torch.tensor(msa).long().view(1, -1, L)
            idx_pdb_orig = torch.arange(L).long().view(1, L)
            idx_pdb = torch.arange(L).long().view(1, L)
            # Shift the residue index of every chain after the first by an
            # extra 500 per preceding chain boundary (presumably to mark chain
            # breaks for the network - confirm against the model code).
            L_prev = 0
            for L_i in Ls[:-1]:
                idx_pdb[:,L_prev+L_i:] += 500 # it was 200 originally.
                L_prev += L_i
            seq = msa[:,0]
            #
            # template features (a leading batch dimension is added)
            xyz_t = xyz_t.float().unsqueeze(0)
            t1d = t1d.float().unsqueeze(0)
            t0d = t0d.float().unsqueeze(0)
            t2d = xyz_to_t2d(xyz_t, t0d)
            #
            # do cropped prediction
            if L > window*2:
                # Accumulators over all crops; averaged by the counts below.
                prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
                count_1d = np.zeros((L,), dtype=np.float32)
                count_2d = np.zeros((L,L), dtype=np.float32)
                node_s = np.zeros((L,MODEL_PARAM['d_msa']), dtype=np.float32)
                #
                grids = np.arange(0, L-window+shift, shift)
                ngrids = grids.shape[0]
                print("ngrid: ", ngrids)
                print("grids: ", grids)
                print("windows: ", window)
                # Every unordered pair of windows (including a window with
                # itself) is run as one crop.
                for i in range(ngrids):
                    for j in range(i, ngrids):
                        start_1 = grids[i]
                        end_1 = min(grids[i]+window, L)
                        start_2 = grids[j]
                        end_2 = min(grids[j]+window, L)
                        # Built-in bool: the np.bool alias no longer exists
                        # in current NumPy releases.
                        sel = np.zeros(L, dtype=bool)
                        sel[start_1:end_1] = True
                        sel[start_2:end_2] = True

                        input_msa = msa[:,:,sel]
                        mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
                        input_msa = input_msa[mask].unsqueeze(0)
                        # At most the first 1000 remaining sequences are used.
                        input_msa = input_msa[:,:1000].to(self.device)
                        input_idx = idx_pdb[:,sel].to(self.device)
                        input_idx_orig = idx_pdb_orig[:,sel]
                        input_seq = input_msa[:,0].to(self.device)
                        #
                        # Select template
                        input_t1d = t1d[:,:,sel].to(self.device) # (B, T, L, 3)
                        input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
                        #
                        print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
                        with torch.cuda.amp.autocast():
                            # The per-crop coordinates and lDDT are not merged
                            # (no merging scheme for init_crds yet), so only
                            # the logits and node features are kept.
                            logit_s, node, _, _ = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d, return_raw=True)
                        #
                        # Original (unshifted) positions covered by this crop.
                        sub_idx = input_idx_orig[0]
                        sub_idx_2d = np.ix_(sub_idx, sub_idx)
                        count_2d[sub_idx_2d] += 1.0
                        count_1d[sub_idx] += 1.0
                        node_s[sub_idx] += node[0].cpu().numpy()
                        for i_logit, logit in enumerate(logit_s):
                            prob = self.active_fn(logit.float()) # calculate distogram
                            prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
                            prob_s[i_logit][sub_idx_2d] += prob
                        del logit_s, node
                #
                # Average the accumulated crops.
                for i in range(4):
                    prob_s[i] = prob_s[i] / count_2d[:,:,None]
                prob_in = np.concatenate(prob_s, axis=-1)
                node_s = node_s / count_1d[:, None]
                #
                # Refine the full-length structure from the averaged features.
                node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
                seq = msa[:,0].to(self.device)
                idx_pdb = idx_pdb.to(self.device)
                prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
                with torch.cuda.amp.autocast():
                    xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
                print (lddt.mean())
            else:
                # Single pass: first 1000 sequences, first 10 templates.
                msa = msa[:,:1000].to(self.device)
                seq = msa[:,0]
                idx_pdb = idx_pdb.to(self.device)
                t1d = t1d[:,:10].to(self.device)
                t2d = t2d[:,:10].to(self.device)
                with torch.cuda.amp.autocast():
                    logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
                print (lddt.mean())
                prob_s = list()
                for logit in logit_s:
                    prob = self.active_fn(logit.float()) # distogram
                    prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
                    prob_s.append(prob)

        # NOTE(review): the flattened source does not show where the no_grad
        # block ends; saving and folding are placed outside it here because
        # TRFold takes lr/nsteps and presumably needs gradients - confirm.
        np.savez_compressed("%s.npz"%out_prefix, dist=prob_s[0].astype(np.float16), \
                            omega=prob_s[1].astype(np.float16),\
                            theta=prob_s[2].astype(np.float16),\
                            phi=prob_s[3].astype(np.float16))

        # run TRFold
        prob_trF = list()
        for prob in prob_s:
            prob = torch.tensor(prob).permute(2,0,1).to(self.device)
            # Small offset, then renormalise over the bin axis.
            prob += 1e-8
            prob = prob / torch.sum(prob, dim=0)[None]
            prob_trF.append(prob)
        # Atom index 1 of each residue is handed to TRFold as the start.
        xyz = xyz[0, :, 1]
        TRF = TRFold(prob_trF, fold_params)
        xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
        print (xyz.shape, lddt[0].shape, seq[0].shape)
        self.write_pdb(seq[0], xyz, Ls, Bfacts=lddt[0], prefix=out_prefix)

    @staticmethod
    def _chain_and_resno(i, Ls):
        """Map the 0-based global residue index ``i`` to ``(chain, resNo)``.

        Chains are lettered ``A``, ``B``, ... following ``Ls``; ``resNo`` is
        1-based within its chain.  Indices past ``sum(Ls)`` fall into the last
        chain.  More than 26 chains raises ``IndexError``.
        """
        pos = i + 1
        # Walk from the last chain backwards; the first chain whose starting
        # offset lies below ``pos`` owns the residue.
        for i_chain in range(len(Ls) - 1, 0, -1):
            tot_res = sum(Ls[:i_chain])
            if pos > tot_res:
                return Predictor._CHAIN_IDS[i_chain], pos - tot_res
        return "A", pos

    def write_pdb(self, seq, atoms, Ls, Bfacts=None, prefix=None):
        """Write the predicted coordinates to ``<prefix>.pdb``.

        Args:
            seq: per-residue amino-acid indices, looked up in ``util.num2aa``.
            atoms: either a 2-D array (one CA position per residue) or an
                array whose second dimension is 3 (N, CA, C per residue).
                Any other layout writes no ATOM records.
            Ls: chain lengths, used to assign chain IDs and residue numbers.
            Bfacts: optional per-residue tensor written to the B-factor
                column after clamping to [0, 1]; zeros if ``None``.
            prefix: output path without the ``.pdb`` extension.
        """
        L = len(seq)
        filename = "%s.pdb"%prefix
        # Identity check: comparing a tensor to None with ``==`` is fragile.
        if Bfacts is None:
            Bfacts = np.zeros(L)
        else:
            Bfacts = torch.clamp( Bfacts, 0, 1)

        ctr = 1  # running atom serial number
        with open(filename, 'wt') as f:
            for i,s in enumerate(seq):
                if (len(atoms.shape)==2):
                    # CA-only trace
                    chain, resNo = self._chain_and_resno(i, Ls)
                    f.write (self._ATOM_FMT%(
                        "ATOM", ctr, " CA ", util.num2aa[s],
                        chain, resNo, atoms[i,0], atoms[i,1], atoms[i,2],
                        1.0, Bfacts[i] ) )
                    ctr += 1
                elif atoms.shape[1]==3:
                    # N, CA, C backbone
                    chain, resNo = self._chain_and_resno(i, Ls)
                    for j,atm_j in enumerate(self._BACKBONE_ATOMS):
                        f.write (self._ATOM_FMT%(
                            "ATOM", ctr, atm_j, util.num2aa[s],
                            chain, resNo, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
                            1.0, Bfacts[i] ) )
                        ctr += 1
def get_args():
    """Parse the command-line options of the complex prediction script.

    Returns:
        argparse.Namespace with the attributes ``model_dir``, ``a3m_fn``,
        ``out_prefix``, ``Ls`` (list of int), ``templ_npz`` and ``use_cpu``.
    """
    import argparse

    default_weights = "%s/weights" % script_dir
    cli = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

    # Location of the network checkpoint directory.
    cli.add_argument(
        "-m",
        dest="model_dir",
        default=default_weights,
        help="Path to pre-trained network weights [%s]" % default_weights,
    )
    # Required input alignment and output prefix.
    cli.add_argument(
        "-i",
        dest="a3m_fn",
        required=True,
        help="Input multiple sequence alignments (in a3m format)",
    )
    cli.add_argument(
        "-o",
        dest="out_prefix",
        required=True,
        help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb",
    )
    # One integer per chain, in the order the chains appear in the alignment.
    cli.add_argument(
        "-Ls",
        dest="Ls",
        required=True,
        nargs="+",
        type=int,
        help="The length of the each subunit (e.g. 220 400)",
    )
    # Optional template features; the help text is shown verbatim because of
    # RawTextHelpFormatter, so its line breaks are significant.
    cli.add_argument(
        "--templ_npz",
        default=None,
        help='''npz file containing complex template information (xyz_t, t1d, t0d). If not provided, zero matrices will be given as templates
- xyz_t: N, CA, C coordinates of complex templates (T, L, 3, 3) For the unaligned region, it should be NaN
- t1d: 1-D features from HHsearch results (score, SS, probab column from atab file) (T, L, 3). For the unaligned region, it should be zeros
- t0d: 0-D features from HHsearch (Probability/100.0, Ideintities/100.0, Similarity fro hhr file) (T, 3)''',
    )
    cli.add_argument("--cpu", dest='use_cpu', default=False, action='store_true')

    return cli.parse_args()
if __name__ == "__main__":
    args = get_args()
    # An existing <out_prefix>.npz means this target was already processed;
    # in that case nothing is loaded or recomputed.
    result_npz = "{}.npz".format(args.out_prefix)
    if not os.path.exists(result_npz):
        pred = Predictor(use_cpu=args.use_cpu, model_dir=args.model_dir)
        pred.predict(
            args.a3m_fn,
            args.out_prefix,
            args.Ls,
            templ_npz=args.templ_npz,
        )