217 lines
6.7 KiB
Python
217 lines
6.7 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
# Default binning configuration read by the conversion / binning helpers.
#   DMIN, DMAX : lower / upper distance limits (same units as the coordinates)
#   DBINS      : number of distance bins between DMIN and DMAX
#   ABINS      : number of angle bins over a full 2*pi turn
PARAMS = dict(
    DMIN=2.0,
    DMAX=20.0,
    DBINS=36,
    ABINS=36,
)
|
|
|
|
# ============================================================
|
|
def get_pair_dist(a, b):
    """Compute all pairwise Euclidean distances between two point sets.

    Parameters
    ----------
    a, b : torch.Tensor of shape [batch, nres, 3]
        Cartesian coordinates of the two atom sets.

    Returns
    -------
    torch.Tensor of shape [batch, nres, nres]
        Entry [k, i, j] is the distance between a[k, i] and b[k, j].
    """
    return torch.cdist(a, b, p=2)
|
|
|
|
# ============================================================
|
|
def get_ang(a, b, c):
    """Calculate the planar angle a-b-c (vertex at b) for every position.

    The three inputs are matched element-wise: for each index the angle
    between the vectors b->a and b->c is returned.

    Parameters
    ----------
    a, b, c : torch.Tensor of shape [..., 3]
        Cartesian coordinates of three sets of atoms (e.g. [batch, nres, 3]).

    Returns
    -------
    ang : torch.Tensor of shape [...]
        Angles in radians, in [0, pi]. Zero-length bonds (a == b or c == b)
        still yield NaN because of the 0/0 normalisation.
    """
    v = a - b
    w = c - b
    # Normalise out of place: an in-place ``/=`` overwrites the tensor that
    # torch.norm saved for its backward pass and makes autograd raise.
    v = v / torch.norm(v, dim=-1, keepdim=True)
    w = w / torch.norm(w, dim=-1, keepdim=True)
    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1]; clamp so acos does not return NaN for (anti)parallel bonds.
    vw = torch.sum(v * w, dim=-1).clamp(-1.0, 1.0)
    return torch.acos(vw)
|
|
|
|
# ============================================================
|
|
def get_dih(a, b, c, d):
    """Calculate the dihedral angle a-b-c-d for every position.

    The four inputs are matched element-wise; the torsion is measured about
    the central b->c bond.

    Parameters
    ----------
    a, b, c, d : torch.Tensor of shape [..., 3]
        Cartesian coordinates of four sets of atoms (e.g. [batch, nres, 3]).

    Returns
    -------
    dih : torch.Tensor of shape [...]
        Signed dihedrals in radians, in the range returned by atan2
        ([-pi, pi]).
    """
    b0 = a - b
    b1 = c - b
    b2 = d - c

    # Unit vector along the central bond. Out of place: an in-place ``/=``
    # would overwrite the tensor torch.norm saved for backward.
    b1 = b1 / torch.norm(b1, dim=-1, keepdim=True)

    # Components of the outer bonds perpendicular to the central axis.
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1
    w = b2 - torch.sum(b2 * b1, dim=-1, keepdim=True) * b1

    # cos- and sin-like projections; atan2 recovers the signed angle.
    x = torch.sum(v * w, dim=-1)
    y = torch.sum(torch.cross(b1, v, dim=-1) * w, dim=-1)

    return torch.atan2(y, x)
|
|
|
|
|
|
# ============================================================
|
|
def xyz_to_c6d(xyz, params=PARAMS):
    """Convert backbone coordinates into 2D distance and orientation maps.

    A virtual Cb is built for every residue from a fixed linear combination
    of its N, CA, C atoms. For every residue pair (i, j) the map stores:

      * dist  : Cb_i - Cb_j distance
      * omega : dihedral CA_i - Cb_i - Cb_j - CA_j
      * theta : dihedral N_i - CA_i - Cb_i - Cb_j
      * phi   : planar angle CA_i - Cb_i - Cb_j

    Parameters
    ----------
    xyz : torch.Tensor of shape [batch, nres, 3, 3]
        Cartesian coordinates of backbone atoms, atom axis ordered N, CA, C.
    params : dict, optional
        Binning parameters; only ``params['DMAX']`` is read here.

    Returns
    -------
    c6d : torch.Tensor of shape [batch, nres, nres, 4]
        Stacked (dist, omega, theta, phi) maps. Self pairs, pairs whose
        distance is NaN and pairs with dist >= DMAX get dist = 999.9 and
        zero angles.
    mask : torch.Tensor of shape [batch, nres, nres]
        1.0 where the pair was closer than DMAX (angles computed), else 0.0.
    """
    batch, nres = xyz.shape[:2]

    # three anchor atoms
    N = xyz[:, :, 0]
    Ca = xyz[:, :, 1]
    C = xyz[:, :, 2]

    # recreate Cb given N, Ca, C
    b = Ca - N
    c = C - Ca
    a = torch.cross(b, c, dim=-1)
    Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

    # 6d coordinates order: (dist, omega, theta, phi)
    c6d = torch.zeros([batch, nres, nres, 4], dtype=xyz.dtype, device=xyz.device)

    dist = get_pair_dist(Cb, Cb)
    dist[torch.isnan(dist)] = 999.9
    # push self pairs (the diagonal) far away so they are never selected
    c6d[..., 0] = dist + 999.9 * torch.eye(nres, device=xyz.device)[None, ...]

    # (batch, residue i, residue j) indices of the pairs closer than DMAX;
    # distinct names avoid shadowing the backbone vector ``b`` above
    bi, ri, rj = torch.where(c6d[..., 0] < params['DMAX'])

    c6d[bi, ri, rj, 1] = get_dih(Ca[bi, ri], Cb[bi, ri], Cb[bi, rj], Ca[bi, rj])
    c6d[bi, ri, rj, 2] = get_dih(N[bi, ri], Ca[bi, ri], Cb[bi, ri], Cb[bi, rj])
    c6d[bi, ri, rj, 3] = get_ang(Ca[bi, ri], Cb[bi, ri], Cb[bi, rj])

    # fix long-range distances
    c6d[..., 0][c6d[..., 0] >= params['DMAX']] = 999.9

    mask = torch.zeros((batch, nres, nres), dtype=xyz.dtype, device=xyz.device)
    mask[bi, ri, rj] = 1.0
    return c6d, mask
|
|
|
|
def xyz_to_t2d(xyz_t, t0d, params=PARAMS):
    """Convert template backbone coordinates into 2D template features.

    Parameters
    ----------
    xyz_t : torch.Tensor of shape [batch, templ, nres, natm, 3]
        Template coordinates. The first three atoms along the atom axis are
        used as backbone N, CA, C (natm == 3 in the original usage; extra
        atoms, if present, are ignored).
    t0d : torch.Tensor of shape [batch, templ, n0d]
        Per-template 0-D features (HHprob, seqID, similarity; n0d == 3),
        broadcast to every residue pair.
    params : dict, optional
        Binning parameters; ``params['DMAX']`` is used here and the dict is
        forwarded to ``xyz_to_c6d``.

    Returns
    -------
    t2d : torch.Tensor of shape [batch, templ, nres, nres, 1 + 6 + n0d]
        Concatenation along the last axis of
          * Cb-Cb distance / DMAX, clamped to [0, 1] (0 for pairs not
            closer than DMAX),
          * sin(omega), sin(theta), sin(phi), cos(omega), cos(theta),
            cos(phi) (all 0 for pairs not closer than DMAX),
          * the broadcast t0d features.
        Any NaN in the result is replaced with 0.
    """
    B, T, L = xyz_t.shape[:3]
    # reshape (unlike view) also accepts non-contiguous input; slicing the
    # first three atoms keeps the N, CA, C layout xyz_to_c6d expects
    bb = xyz_t[:, :, :, :3].reshape(B * T, L, 3, 3)
    c6d, mask = xyz_to_c6d(bb, params=params)
    c6d = c6d.view(B, T, L, L, 4)
    mask = mask.view(B, T, L, L, 1)

    dist = c6d[..., :1] * mask / params['DMAX']  # from 0 to 1, (B, T, L, L, 1)
    dist = torch.clamp(dist, 0.0, 1.0)
    orien = torch.cat((torch.sin(c6d[..., 1:]), torch.cos(c6d[..., 1:])), dim=-1) * mask  # (B, T, L, L, 6)
    t0d = t0d.unsqueeze(2).unsqueeze(3).expand(-1, -1, L, L, -1)

    t2d = torch.cat((dist, orien, t0d), dim=-1)
    t2d[torch.isnan(t2d)] = 0.0
    return t2d
|
|
|
|
# ============================================================
|
|
def c6d_to_bins(c6d, params=PARAMS):
    """Discretise 2D distance/orientation maps with ``torch.bucketize``.

    Parameters
    ----------
    c6d : torch.Tensor of shape [..., 4]
        Stacked (dist, omega, theta, phi) maps, e.g. from ``xyz_to_c6d``.
    params : dict, optional
        Uses DMIN, DMAX, DBINS and ABINS.

    Returns
    -------
    torch.Tensor of shape [..., 4], dtype uint8
        Bin indices for (dist, omega, theta, phi). Wherever the distance
        lands in the overflow bin DBINS, omega/theta are set to ABINS and
        phi to ABINS // 2.
    """
    n_dist = params['DBINS']
    n_ang = params['ABINS']
    dstep = (params['DMAX'] - params['DMIN']) / n_dist
    astep = 2.0 * np.pi / n_ang
    like = {'dtype': c6d.dtype, 'device': c6d.device}

    # upper edges of the bins
    dist_edges = torch.linspace(params['DMIN'] + dstep, params['DMAX'], n_dist, **like)
    full_turn = torch.linspace(-np.pi + astep, np.pi, n_ang, **like)
    half_turn = torch.linspace(astep, np.pi, n_ang // 2, **like)

    db, ob, tb, pb = (
        torch.bucketize(c6d[..., k].contiguous(), edges)
        for k, edges in enumerate((dist_edges, full_turn, full_turn, half_turn))
    )

    no_contact = db == n_dist
    ob[no_contact] = n_ang
    tb[no_contact] = n_ang
    pb[no_contact] = n_ang // 2

    return torch.stack([db, ob, tb, pb], dim=-1).to(torch.uint8)
|
|
|
|
|
|
# ============================================================
|
|
def dist_to_bins(dist, params=PARAMS):
    """Discretise a 2D distance map into DBINS + 1 integer bins.

    With dstep = (DMAX - DMIN) / DBINS, bin k (0 <= k < DBINS) covers
    approximately [DMIN + k*dstep, DMIN + (k+1)*dstep). Exact edges are
    resolved by torch.round, which rounds half to even. Distances below
    DMIN share bin 0, and distances at or beyond DMAX share the overflow
    bin DBINS.

    Parameters
    ----------
    dist : torch.Tensor
        Distance map of any shape.
    params : dict, optional
        Uses DMIN, DMAX and DBINS.

    Returns
    -------
    torch.Tensor (int64), same shape as ``dist``
        Bin indices in [0, DBINS]. NaN distances are assigned to DBINS.
    """
    dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
    db = torch.round((dist - params['DMIN'] - dstep / 2) / dstep)

    # d < DMIN -> bin 0; anything past the last bin -> overflow bin DBINS
    db = torch.clamp(db, 0, params['DBINS'])
    # NaN survives round/clamp and its integer cast is undefined; treat it as
    # "no contact", the same way xyz_to_c6d treats NaN distances.
    db[torch.isnan(db)] = params['DBINS']

    return db.long()
|
|
|
|
|
|
# ============================================================
|
|
def c6d_to_bins2(c6d, params=PARAMS):
    """Discretise 2D distance/orientation maps by rounding to bin indices.

    Parameters
    ----------
    c6d : torch.Tensor of shape [..., 4]
        Stacked (dist, omega, theta, phi) maps, e.g. from ``xyz_to_c6d``.
        omega/theta are expected in [-pi, pi] and phi in [0, pi].
    params : dict, optional
        Uses DMIN, DMAX, DBINS and ABINS.

    Returns
    -------
    torch.Tensor of shape [..., 4], dtype int64
        Bin indices for (dist, omega, theta, phi). Regular angle bins are
        0..ABINS-1 (omega, theta) and 0..ABINS//2-1 (phi). Wherever the
        distance bin is the overflow bin DBINS, omega/theta are set to ABINS
        and phi to ABINS // 2.
    """
    n_dist = params['DBINS']
    n_ang = params['ABINS']
    dstep = (params['DMAX'] - params['DMIN']) / n_dist
    astep = 2.0 * np.pi / n_ang

    db = torch.round((c6d[..., 0] - params['DMIN'] - dstep / 2) / dstep)
    ob = torch.round((c6d[..., 1] + np.pi - astep / 2) / astep)
    tb = torch.round((c6d[..., 2] + np.pi - astep / 2) / astep)
    pb = torch.round((c6d[..., 3] - astep / 2) / astep)

    # put all d < DMIN into one bin and everything past the last bin into
    # the overflow bin DBINS
    db = torch.clamp(db, 0, n_dist)

    # torch.round rounds half to even, so an angle of exactly +pi (omega,
    # theta) or pi (phi) lands on ABINS (resp. ABINS // 2) for even ABINS.
    # That would collide with the no-contact bins set below, so keep such
    # angles in the last regular bin, as c6d_to_bins does.
    ob = torch.clamp(ob, 0, n_ang - 1)
    tb = torch.clamp(tb, 0, n_ang - 1)
    pb = torch.clamp(pb, 0, n_ang // 2 - 1)

    # synchronize no-contact bins
    no_contact = db == n_dist
    ob[no_contact] = n_ang
    tb[no_contact] = n_ang
    pb[no_contact] = n_ang // 2

    return torch.stack([db, ob, tb, pb], dim=-1).long()
|