318 lines
13 KiB
Python
318 lines
13 KiB
Python
import sys, os
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils import data
|
|
from parsers import parse_a3m, read_templates
|
|
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
|
|
import util
|
|
from collections import namedtuple
|
|
from ffindex import *
|
|
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
|
|
from trFold import TRFold
|
|
|
|
# Parent directory of the directory containing this script; the default
# weights location is "<script_dir>/weights" (see Predictor / get_args).
# os.path.dirname is used instead of splitting on '/' so the path is also
# correct on platforms whose separator is not '/'.
script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Number of bins of each of the four inter-residue probability maps, in the
# order they are saved to the output .npz (dist, omega, theta, phi).
NBIN = [37, 37, 37, 19]
# Settings attached to MODEL_PARAM under the "SE3_param" key
# (presumably the SE(3)-equivariant structure module — confirm in
# RoseTTAFoldModel).
SE3_param = dict(
    num_layers=2,
    num_channels=16,
    num_degrees=2,
    l0_in_features=32,
    l0_out_features=8,
    l1_in_features=3,
    l1_out_features=3,
    num_edge_features=32,
    div=2,
    n_heads=4,
)

# Settings attached to MODEL_PARAM under the "REF_param" key
# (presumably the refinement module — confirm in RoseTTAFoldModel).
REF_param = dict(
    num_layers=3,
    num_channels=32,
    num_degrees=3,
    l0_in_features=32,
    l0_out_features=8,
    l1_in_features=3,
    l1_out_features=3,
    num_edge_features=32,
    div=4,
    n_heads=4,
)

# Keyword arguments for RoseTTAFoldModule_e2e(**MODEL_PARAM).
# The two sub-configurations above are stored by reference as the last entries.
MODEL_PARAM = dict(
    n_module=8,
    n_module_str=4,
    n_module_ref=4,
    n_layer=1,
    d_msa=384,
    d_pair=288,
    d_templ=64,
    n_head_msa=12,
    n_head_pair=8,
    n_head_templ=4,
    d_hidden=64,
    r_ff=4,
    n_resblock=1,
    p_drop=0.1,
    use_templ=True,
    performer_N_opts=dict(nb_features=64),
    performer_L_opts=dict(nb_features=64),
    SE3_param=SE3_param,
    REF_param=REF_param,
)
# Parameters handed to TRFold(prob_trF, fold_params) for the folding step.
fold_params = dict(
    # Savitzky-Golay smoothing kernels (7- and 9-point); each sums to 1.
    SG7=np.array([[[-2, 3, 6, 7, 6, 3, -2]]]) / 21,
    SG9=np.array([[[-21, 14, 39, 54, 59, 54, 39, 14, -21]]]) / 231,
    DCUT=19.5,
    ALPHA=1.57,
    # Backbone motif, presumably ideal N / CA / C coordinates with CA at the
    # origin (row order follows the key name).
    # TODO: add Cb to the motif
    NCAC=np.array([[-0.676, -1.294, 0.],
                   [0., 0., 0.],
                   [1.5, -0.174, 0.]], dtype=np.float32),
    CLASH=2.0,
    PCUT=0.5,
    DSTEP=0.5,
    ASTEP=np.deg2rad(10.0),
    XYZRAD=7.5,
    WANG=0.1,
    WCST=0.1,
)

# The generic "SG" entry aliases (does not copy) the 9-point kernel.
fold_params["SG"] = fold_params["SG9"]
class Predictor():
    """RoseTTAFold end-to-end predictor for multi-chain complexes.

    Construction builds ``RoseTTAFoldModule_e2e(**MODEL_PARAM)`` and loads the
    checkpoint ``<model_dir>/RoseTTAFold_e2e.pt``. :meth:`predict` writes the
    predicted inter-residue probability maps to ``<out_prefix>.npz`` and a
    TRFold-folded backbone model to ``<out_prefix>.pdb``.
    """

    def __init__(self, model_dir=None, use_cpu=False):
        """Build the network on the chosen device and load its weights.

        Args:
            model_dir: directory holding the checkpoint; defaults to
                ``<script_dir>/weights``.
            use_cpu: run on CPU even when CUDA is available.

        Exits the process with status 1 if the checkpoint file is missing.
        """
        if model_dir is None:
            self.model_dir = "%s/weights" % (script_dir)
        else:
            self.model_dir = model_dir
        # define model name
        self.model_name = "RoseTTAFold"
        if torch.cuda.is_available() and (not use_cpu):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Softmax over dim 1, the bin axis of each logit map.
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)
        could_load = self.load_model(self.model_name)
        if not could_load:
            print("ERROR: failed to load model", file=sys.stderr)
            # Non-zero status so shell pipelines can detect the failure
            # (a bare sys.exit() reports success).
            sys.exit(1)

    def load_model(self, model_name, suffix='e2e'):
        """Load ``<model_dir>/<model_name>_<suffix>.pt`` into ``self.model``.

        Returns:
            True on success, False if the checkpoint file does not exist.
        """
        chk_fn = "%s/%s_%s.pt" % (self.model_dir, model_name, suffix)
        if not os.path.exists(chk_fn):
            return False
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(chk_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        return True

    @staticmethod
    def _load_templates(templ_npz, L):
        """Return (xyz_t, t1d, t0d) tensors from *templ_npz*, or empty placeholders of length L."""
        if templ_npz is None:
            xyz_t = torch.full((1, L, 3, 3), np.nan).float()
            t1d = torch.zeros((1, L, 3)).float()
            t0d = torch.zeros((1, 3)).float()
            return xyz_t, t1d, t0d
        # Context manager closes the underlying npz file handle.
        with np.load(templ_npz) as templ:
            xyz_t = torch.from_numpy(templ["xyz_t"])
            t1d = torch.from_numpy(templ["t1d"])
            t0d = torch.from_numpy(templ["t0d"])
        return xyz_t, t1d, t0d

    @staticmethod
    def _crop_mask(L, start_1, end_1, start_2, end_2):
        """Boolean mask of length L selecting [start_1, end_1) and [start_2, end_2)."""
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        sel = np.zeros(L, dtype=bool)
        sel[start_1:end_1] = True
        sel[start_2:end_2] = True
        return sel

    def predict(self, a3m_fn, out_prefix, Ls, templ_npz=None, window=1000, shift=100):
        """Predict a complex structure from a paired MSA and write the outputs.

        Args:
            a3m_fn: input alignment in a3m format (parsed with ``parse_a3m``).
            out_prefix: writes ``<out_prefix>.npz`` (dist/omega/theta/phi
                probabilities, float16) and ``<out_prefix>.pdb``.
            Ls: chain lengths in MSA column order; used to offset residue
                indices between chains and to assign PDB chain IDs.
            templ_npz: optional npz with ``xyz_t``, ``t1d``, ``t0d``; when
                omitted, a NaN/zero placeholder template is used.
            window: crop length. Inputs longer than ``2 * window`` residues are
                run as every pair of crops (up to ``2 * window`` residues per
                forward pass), and the per-crop outputs are averaged.
            shift: stride between crop start positions.
        """
        msa = parse_a3m(a3m_fn)
        _, L = msa.shape
        xyz_t, t1d, t0d = self._load_templates(templ_npz, L)

        # Mixed precision only makes sense for CUDA; avoids the CPU warning.
        use_amp = self.device.type == "cuda"

        self.model.eval()
        with torch.no_grad():
            msa = torch.tensor(msa).long().view(1, -1, L)
            idx_pdb_orig = torch.arange(L).long().view(1, L)
            idx_pdb = torch.arange(L).long().view(1, L)
            # Shift residue indices of every later chain by 500, presumably so
            # the network sees a chain break between subunits.
            L_prev = 0
            for L_i in Ls[:-1]:
                idx_pdb[:, L_prev + L_i:] += 500  # it was 200 originally.
                L_prev += L_i

            # template features
            xyz_t = xyz_t.float().unsqueeze(0)
            t1d = t1d.float().unsqueeze(0)
            t0d = t0d.float().unsqueeze(0)
            t2d = xyz_to_t2d(xyz_t, t0d)

            if L > window * 2:
                # do cropped prediction: accumulate per-crop outputs and the
                # number of times each residue / pair was covered.
                prob_s = [np.zeros((L, L, NBIN[i]), dtype=np.float32) for i in range(4)]
                count_1d = np.zeros((L,), dtype=np.float32)
                count_2d = np.zeros((L, L), dtype=np.float32)
                node_s = np.zeros((L, MODEL_PARAM['d_msa']), dtype=np.float32)

                grids = np.arange(0, L - window + shift, shift)
                ngrids = grids.shape[0]
                print("ngrid: ", ngrids)
                print("grids: ", grids)
                print("windows: ", window)

                for i in range(ngrids):
                    for j in range(i, ngrids):
                        start_1 = grids[i]
                        end_1 = min(grids[i] + window, L)
                        start_2 = grids[j]
                        end_2 = min(grids[j] + window, L)
                        sel = self._crop_mask(L, start_1, end_1, start_2, end_2)

                        input_msa = msa[:, :, sel]
                        # remove too gappy sequences (token 20 is presumably the gap)
                        mask = torch.sum(input_msa == 20, dim=-1) < 0.5 * sel.sum()
                        input_msa = input_msa[mask].unsqueeze(0)
                        input_msa = input_msa[:, :1000].to(self.device)
                        input_idx = idx_pdb[:, sel].to(self.device)
                        input_idx_orig = idx_pdb_orig[:, sel]
                        input_seq = input_msa[:, 0].to(self.device)

                        # Select template
                        input_t1d = t1d[:, :, sel].to(self.device)  # (B, T, L, 3)
                        input_t2d = t2d[:, :, sel][:, :, :, sel].to(self.device)

                        print("running crop: %d-%d/%d-%d" % (start_1, end_1, start_2, end_2), input_msa.shape)
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            # Per-crop coordinates and lDDT are not merged.
                            logit_s, node, _, _ = self.model(
                                input_msa, input_seq, input_idx,
                                t1d=input_t1d, t2d=input_t2d, return_raw=True)

                        sub_idx = input_idx_orig[0].numpy()
                        sub_idx_2d = np.ix_(sub_idx, sub_idx)
                        count_2d[sub_idx_2d] += 1.0
                        count_1d[sub_idx] += 1.0
                        node_s[sub_idx] += node[0].float().cpu().numpy()
                        for i_logit, logit in enumerate(logit_s):
                            prob = self.active_fn(logit.float())  # calculate distogram
                            prob = prob.squeeze(0).permute(1, 2, 0).cpu().numpy()
                            prob_s[i_logit][sub_idx_2d] += prob
                        del logit_s, node

                # Average the overlapping crop contributions.
                for i in range(len(prob_s)):
                    prob_s[i] = prob_s[i] / count_2d[:, :, None]
                prob_in = np.concatenate(prob_s, axis=-1)
                node_s = node_s / count_1d[:, None]

                # Refinement pass on the full length from the merged features.
                node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
                seq = msa[:, 0].to(self.device)
                idx_pdb = idx_pdb.to(self.device)
                prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
                print(lddt.mean())
            else:
                # Whole-length prediction: first 1000 sequences, first 10 templates.
                msa = msa[:, :1000].to(self.device)
                seq = msa[:, 0]
                idx_pdb = idx_pdb.to(self.device)
                t1d = t1d[:, :10].to(self.device)
                t2d = t2d[:, :10].to(self.device)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
                print(lddt.mean())
                prob_s = list()
                for logit in logit_s:
                    prob = self.active_fn(logit.float())  # distogram
                    prob = prob.reshape(-1, L, L).permute(1, 2, 0).cpu().numpy()
                    prob_s.append(prob)

            np.savez_compressed("%s.npz" % out_prefix,
                                dist=prob_s[0].astype(np.float16),
                                omega=prob_s[1].astype(np.float16),
                                theta=prob_s[2].astype(np.float16),
                                phi=prob_s[3].astype(np.float16))

            # run TRFold: renormalise every map over its bin axis, with a small
            # epsilon so no bin is exactly zero.
            prob_trF = list()
            for prob in prob_s:
                prob = torch.tensor(prob).permute(2, 0, 1).to(self.device)
                prob += 1e-8
                prob = prob / torch.sum(prob, dim=0)[None]
                prob_trF.append(prob)
            # Start folding from atom index 1 of each residue (presumably CA).
            xyz = xyz[0, :, 1]
            TRF = TRFold(prob_trF, fold_params)
            xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
            print(xyz.shape, lddt[0].shape, seq[0].shape)
            self.write_pdb(seq[0], xyz, Ls, Bfacts=lddt[0], prefix=out_prefix)

    @staticmethod
    def _as_numpy(x):
        """Return *x* as a host NumPy array; torch tensors are detached and copied to CPU once."""
        if torch.is_tensor(x):
            x = x.detach()
            if x.is_floating_point():
                # float32 so half/bfloat16 outputs convert cleanly to NumPy.
                x = x.float()
            return x.cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _chain_and_resno(i, Ls, chain_ids="ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
        """Return (chain ID, 1-based residue number in that chain) for 0-based residue *i*."""
        for i_chain in range(len(Ls) - 1, 0, -1):
            tot_res = sum(Ls[:i_chain])
            if i + 1 > tot_res:
                return chain_ids[i_chain], i + 1 - tot_res
        return chain_ids[0], i + 1

    @staticmethod
    def _format_atom_line(serial, atom_name, resname, chain, resno, x, y, z, bfac, occupancy=1.0):
        """Format one fixed-column PDB ATOM record (x/y/z in columns 31-54)."""
        # Four spaces after resSeq: iCode (col 27) plus blank cols 28-30.
        return "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
            "ATOM", serial, atom_name, resname, chain, resno,
            x, y, z, occupancy, bfac)

    def write_pdb(self, seq, atoms, Ls, Bfacts=None, prefix=None):
        """Write a backbone model in PDB format to ``<prefix>.pdb``.

        Args:
            seq: per-residue amino-acid indices, mapped to names via ``util.num2aa``.
            atoms: coordinates, either (L, 3) for a CA-only trace or (L, 3, 3)
                for N/CA/C per residue; torch tensors or NumPy arrays.
                Any other layout writes no ATOM records.
            Ls: chain lengths; residues get chain IDs A, B, ... in order and are
                numbered from 1 within each chain (at most 26 chains).
            Bfacts: optional per-residue values for the B-factor column,
                clipped to [0, 1]; zeros when omitted.
            prefix: output path prefix.
        """
        L = len(seq)
        filename = "%s.pdb" % prefix
        # Convert once up front instead of syncing the device per formatted value.
        seq = self._as_numpy(seq)
        atoms = self._as_numpy(atoms)
        if Bfacts is None:
            Bfacts = np.zeros(L)
        else:
            Bfacts = np.clip(self._as_numpy(Bfacts), 0, 1)

        if atoms.ndim == 2:
            atoms = atoms[:, None, :]  # CA-only trace -> (L, 1, 3)
            atom_names = (" CA ",)
        elif atoms.shape[1] == 3:
            # PDB atom names are 4 columns wide, element left-justified after one space.
            atom_names = (" N  ", " CA ", " C  ")
        else:
            atom_names = ()

        ctr = 1
        with open(filename, 'wt') as f:
            for i, s in enumerate(seq):
                chain, resNo = self._chain_and_resno(i, Ls)
                resname = util.num2aa[s]
                for j, atm_j in enumerate(atom_names):
                    f.write(self._format_atom_line(
                        ctr, atm_j, resname, chain, resNo,
                        atoms[i, j, 0], atoms[i, j, 1], atoms[i, j, 2], Bfacts[i]))
                    ctr += 1
def get_args(argv=None):
    """Parse command-line options for complex structure prediction.

    Args:
        argv: argument list to parse; None (the default) parses sys.argv[1:].

    Returns:
        argparse.Namespace with model_dir, a3m_fn, out_prefix, Ls, templ_npz
        and use_cpu.
    """
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-m", dest="model_dir", default="%s/weights" % (script_dir),
                        help="Path to pre-trained network weights [%s/weights]" % script_dir)
    parser.add_argument("-i", dest="a3m_fn", required=True,
                        help="Input multiple sequence alignments (in a3m format)")
    parser.add_argument("-o", dest="out_prefix", required=True,
                        help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb")
    parser.add_argument("-Ls", dest="Ls", required=True, nargs="+", type=int,
                        help="The length of each subunit (e.g. 220 400)")
    parser.add_argument("--templ_npz", default=None,
                        help='''npz file containing complex template information (xyz_t, t1d, t0d). If not provided, zero matrices will be given as templates
- xyz_t: N, CA, C coordinates of complex templates (T, L, 3, 3) For the unaligned region, it should be NaN
- t1d: 1-D features from HHsearch results (score, SS, probab column from atab file) (T, L, 3). For the unaligned region, it should be zeros
- t0d: 0-D features from HHsearch (Probability/100.0, Identities/100.0, Similarity from hhr file) (T, 3)''')
    parser.add_argument("--cpu", dest='use_cpu', default=False, action='store_true',
                        help="Run on CPU even if a CUDA device is available")

    return parser.parse_args(argv)
if __name__ == "__main__":
    args = get_args()
    npz_fn = "%s.npz" % args.out_prefix
    pdb_fn = "%s.pdb" % args.out_prefix
    # predict() saves the .npz before running TRFold and writing the .pdb, so
    # an existing .npz alone does not mean a finished run (e.g. folding failed).
    # Skip only when both outputs exist, and say so instead of exiting silently.
    if os.path.exists(npz_fn) and os.path.exists(pdb_fn):
        print("Outputs %s and %s already exist; skipping prediction." % (npz_fn, pdb_fn))
    else:
        pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
        pred.predict(args.a3m_fn, args.out_prefix, args.Ls, templ_npz=args.templ_npz)