[RoseTTAFold] Initial release
This commit is contained in:
parent
26d8955cc5
commit
0db746b4ab
73
DGLPyTorch/DrugDiscovery/RoseTTAFold/Dockerfile
Normal file
73
DGLPyTorch/DrugDiscovery/RoseTTAFold/Dockerfile
Normal file
|
@ -0,0 +1,73 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.09-py3
|
||||
|
||||
FROM ${FROM_IMAGE_NAME} AS dgl_builder
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y git build-essential python3-dev make cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /dgl
|
||||
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
|
||||
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
|
||||
WORKDIR build
|
||||
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
|
||||
RUN make -j8
|
||||
|
||||
|
||||
FROM ${FROM_IMAGE_NAME}
|
||||
|
||||
# VERY IMPORTANT, DO NOT REMOVE:
|
||||
ENV FORCE_CUDA 1
|
||||
RUN pip install -v torch-geometric
|
||||
RUN pip install -v torch-scatter
|
||||
RUN pip install -v torch-sparse
|
||||
RUN pip install -v torch-cluster
|
||||
RUN pip install -v torch-spline-conv
|
||||
|
||||
|
||||
# copy built DGL and install it
|
||||
COPY --from=dgl_builder /dgl ./dgl
|
||||
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
|
||||
ENV DGLBACKEND=pytorch
|
||||
#RUN pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html
|
||||
|
||||
|
||||
# HH-Suite
|
||||
RUN git clone https://github.com/soedinglab/hh-suite.git && \
|
||||
mkdir -p hh-suite/build
|
||||
WORKDIR hh-suite/build
|
||||
RUN cmake .. && \
|
||||
make && \
|
||||
make install
|
||||
|
||||
|
||||
# PSIPRED
|
||||
WORKDIR /workspace
|
||||
RUN wget http://wwwuser.gwdg.de/~compbiol/data/csblast/releases/csblast-2.2.3_linux64.tar.gz -O csblast-2.2.3.tar.gz && \
|
||||
mkdir -p csblast-2.2.3 && \
|
||||
tar xf csblast-2.2.3.tar.gz -C csblast-2.2.3 --strip-components=1 && \
|
||||
rm csblast-2.2.3.tar.gz
|
||||
|
||||
RUN wget https://ftp.ncbi.nlm.nih.gov/blast/executables/legacy.NOTSUPPORTED/2.2.26/blast-2.2.26-x64-linux.tar.gz && \
|
||||
tar xf blast-2.2.26-x64-linux.tar.gz && \
|
||||
rm blast-2.2.26-x64-linux.tar.gz
|
||||
|
||||
RUN wget http://bioinfadmin.cs.ucl.ac.uk/downloads/psipred/psipred.4.02.tar.gz && \
|
||||
tar xf psipred.4.02.tar.gz && \
|
||||
rm psipred.4.02.tar.gz
|
||||
|
||||
|
||||
ADD . /workspace/rf
|
||||
WORKDIR /workspace/rf
|
||||
|
||||
RUN wget https://openstructure.org/static/lddt-linux.zip -O lddt.zip && unzip -d lddt -j lddt.zip
|
||||
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install -r requirements.txt
|
21
DGLPyTorch/DrugDiscovery/RoseTTAFold/LICENSE
Normal file
21
DGLPyTorch/DrugDiscovery/RoseTTAFold/LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021 RosettaCommons
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
94
DGLPyTorch/DrugDiscovery/RoseTTAFold/README-ROSETTAFOLD.md
Normal file
94
DGLPyTorch/DrugDiscovery/RoseTTAFold/README-ROSETTAFOLD.md
Normal file
|
@ -0,0 +1,94 @@
|
|||
# *RoseTTAFold*
|
||||
This package contains deep learning models and related scripts to run RoseTTAFold.
|
||||
This repository is the official implementation of RoseTTAFold: Accurate prediction of protein structures and interactions using a 3-track network.
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the package
|
||||
```
|
||||
git clone https://github.com/RosettaCommons/RoseTTAFold.git
|
||||
cd RoseTTAFold
|
||||
```
|
||||
|
||||
2. Create conda environment using `RoseTTAFold-linux.yml` file and `folding-linux.yml` file. The latter is required to run a pyrosetta version only (run_pyrosetta_ver.sh).
|
||||
```
|
||||
# create conda environment for RoseTTAFold
|
||||
# If your NVIDIA driver compatible with cuda11
|
||||
conda env create -f RoseTTAFold-linux.yml
|
||||
# If not (but compatible with cuda10)
|
||||
conda env create -f RoseTTAFold-linux-cu101.yml
|
||||
|
||||
# create conda environment for pyRosetta folding & running DeepAccNet
|
||||
conda env create -f folding-linux.yml
|
||||
```
|
||||
|
||||
3. Download network weights (under Rosetta-DL Software license -- please see below)
|
||||
While the code is licensed under the MIT License, the trained weights and data for RoseTTAFold are made available for non-commercial use only under the terms of the Rosetta-DL Software license. You can find details at https://files.ipd.uw.edu/pub/RoseTTAFold/Rosetta-DL_LICENSE.txt
|
||||
|
||||
```
|
||||
wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
|
||||
tar xfz weights.tar.gz
|
||||
```
|
||||
|
||||
4. Download and install third-party software.
|
||||
```
|
||||
./install_dependencies.sh
|
||||
```
|
||||
|
||||
5. Download sequence and structure databases
|
||||
```
|
||||
# uniref30 [46G]
|
||||
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz
|
||||
mkdir -p UniRef30_2020_06
|
||||
tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06
|
||||
|
||||
# BFD [272G]
|
||||
wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
|
||||
mkdir -p bfd
|
||||
tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
|
||||
|
||||
# structure templates (including *_a3m.ffdata, *_a3m.ffindex) [over 100G]
|
||||
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz
|
||||
tar xfz pdb100_2021Mar03.tar.gz
|
||||
# for CASP14 benchmarks, we used this one: https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2020Mar11.tar.gz
|
||||
```
|
||||
|
||||
6. Obtain a [PyRosetta licence](https://els2.comotion.uw.edu/product/pyrosetta) and install the package in the newly created `folding` conda environment ([link](http://www.pyrosetta.org/downloads)).
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
# For monomer structure prediction
|
||||
cd example
|
||||
../run_[pyrosetta, e2e]_ver.sh input.fa .
|
||||
|
||||
# For complex modeling
|
||||
# please see README file under example/complex_modeling/README for details.
|
||||
python network/predict_complex.py -i paired.a3m -o complex -Ls 218 310
|
||||
```
|
||||
|
||||
## Expected outputs
|
||||
For the pyrosetta version, user will get five final models having estimated CA rms error at the B-factor column (model/model_[1-5].crderr.pdb).
|
||||
For the end-to-end version, there will be a single PDB output having estimated residue-wise CA-lddt at the B-factor column (t000_.e2e.pdb).
|
||||
|
||||
## FAQ
|
||||
1. Segmentation fault while running hhblits/hhsearch
|
||||
For easy install, we used a statically compiled version of hhsuite (installed through conda). Currently, we're not sure what exactly causes segmentation fault error in some cases, but we found that it might be resolved if you compile hhsuite from source and use this compiled version instead of conda version. For installation of hhsuite, please see [here](https://github.com/soedinglab/hh-suite).
|
||||
|
||||
2. Submitting jobs to computing nodes
|
||||
The modeling pipeline provided here (run_pyrosetta_ver.sh/run_e2e_ver.sh) is a kind of guidelines to show how RoseTTAFold works. For more efficient use of computing resources, you might want to modify the provided bash script to submit separate jobs with proper dependencies for each of steps (more cpus/memory for hhblits/hhsearch, using gpus only for running the networks, etc).
|
||||
|
||||
## Links:
|
||||
|
||||
* [Robetta server](https://robetta.bakerlab.org/) (RoseTTAFold option)
|
||||
* [RoseTTAFold models for CASP14 targets](https://files.ipd.uw.edu/pub/RoseTTAFold/casp14_models.tar.gz) [input MSA and hhsearch files are included]
|
||||
|
||||
## Credit to performer-pytorch and SE(3)-Transformer codes
|
||||
The code in the network/performer_pytorch.py is strongly based on [this repo](https://github.com/lucidrains/performer-pytorch) which is pytorch implementation of [Performer architecture](https://arxiv.org/abs/2009.14794).
|
||||
The codes in network/equivariant_attention is from the original SE(3)-Transformer [repo](https://github.com/FabianFuchsML/se3-transformer-public) which accompanies [the paper](https://arxiv.org/abs/2006.10503) 'SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks' by Fabian et al.
|
||||
|
||||
|
||||
## References
|
||||
|
||||
M Baek, et al., Accurate prediction of protein structures and interactions using a 3-track network, bioRxiv (2021). [link](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1)
|
||||
|
138
DGLPyTorch/DrugDiscovery/RoseTTAFold/README.md
Normal file
138
DGLPyTorch/DrugDiscovery/RoseTTAFold/README.md
Normal file
|
@ -0,0 +1,138 @@
|
|||
# RoseTTAFold for PyTorch
|
||||
|
||||
This repository provides a script to run inference using the RoseTTAFold model. The content of this repository is tested and maintained by NVIDIA.
|
||||
|
||||
## Table Of Contents
|
||||
|
||||
- [Model overview](#model-overview)
|
||||
* [Model architecture](#model-architecture)
|
||||
- [Setup](#setup)
|
||||
* [Requirements](#requirements)
|
||||
- [Quick Start Guide](#quick-start-guide)
|
||||
- [Release notes](#release-notes)
|
||||
* [Changelog](#changelog)
|
||||
* [Known issues](#known-issues)
|
||||
|
||||
|
||||
|
||||
## Model overview
|
||||
|
||||
The RoseTTAFold is a model designed to provide accurate protein structure from its amino acid sequence. This model is
|
||||
based on [Accurate prediction of protein structures and interactions using a 3-track network](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1) by Minkyung Baek et al.
|
||||
|
||||
This implementation is a dockerized version of the official [RoseTTAFold repository](https://github.com/RosettaCommons/RoseTTAFold/).
|
||||
Here you can find the [original RoseTTAFold guide](README-ROSETTAFOLD.md).
|
||||
|
||||
### Model architecture
|
||||
|
||||
The RoseTTAFold model is based on a 3-track architecture fusing 1D, 2D, and 3D information about the protein structure.
|
||||
All information is exchanged between tracks to learn the sequence and coordinate patterns at the same time. The final prediction
|
||||
is refined using an SE(3)-Transformer.
|
||||
|
||||
<img src="images/NetworkArchitecture.jpg" width="900"/>
|
||||
|
||||
*Figure 1: The RoseTTAFold architecture. Image comes from the [original paper](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1).*
|
||||
|
||||
## Setup
|
||||
|
||||
The following section lists the requirements that you need to meet in order to run inference using the RoseTTAFold model.
|
||||
|
||||
### Requirements
|
||||
|
||||
This repository contains a Dockerfile that extends the PyTorch NGC container and encapsulates necessary dependencies. Aside from these dependencies, ensure you have the following components:
|
||||
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
|
||||
- PyTorch 21.09-py3 NGC container
|
||||
- Supported GPUs:
|
||||
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
|
||||
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
|
||||
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
|
||||
|
||||
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
|
||||
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
|
||||
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
|
||||
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
|
||||
|
||||
For those unable to use the PyTorch NGC container, to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
|
||||
|
||||
In addition, 1 TB of disk space is required to unpack the required databases.
|
||||
|
||||
## Quick Start Guide
|
||||
|
||||
To run inference using the RoseTTAFold model, perform the following steps using the default parameters.
|
||||
|
||||
1. Clone the repository.
|
||||
```
|
||||
git clone https://github.com/NVIDIA/DeepLearningExamples
|
||||
cd DeepLearningExamples/DGLPyTorch/
|
||||
```
|
||||
|
||||
2. Download the pre-trained weights and databases needed for inference.
|
||||
The following command downloads the pre-trained weights and two databases needed to create derived features to the input to the model.
|
||||
The script will download the `UniRef30` (~50 GB) and `pdb100_2021Mar03` (~115 GB) databases, which might take a considerable amount
|
||||
of time. Additionally, unpacking those databases requires approximately 1 TB of free disk space.
|
||||
|
||||
By default, the data will be downloaded to `./weights` and `./databases` folders in the current directory.
|
||||
```
|
||||
bash scripts/download_databases.sh
|
||||
```
|
||||
If you would like to specify the download location you can pass the following parameters
|
||||
```
|
||||
bash scripts/download_databases.sh PATH-TO-WEIGHTS PATH-TO-DATABASES
|
||||
```
|
||||
|
||||
3. Build the RoseTTAFold PyTorch NGC container. This step builds the PyTorch dependencies on your machine and can take between 30 minutes and 1 hour to complete.
|
||||
```
|
||||
docker build -t rosettafold .
|
||||
```
|
||||
|
||||
4. Start an interactive session in the NGC container to run inference.
|
||||
|
||||
The following command launches the container and mount the `PATH-TO-WEIGHTS` directory as a volume to the `/weights` directory in the container, the `PATH-TO-DATABASES` directory as a volume to the `/databases` directory in the container, and `./results` directory to the `/results` directory in the container.
|
||||
```
|
||||
mkdir data results
|
||||
docker run --ipc=host -it --rm --runtime=nvidia -p6006:6006 -v PATH-TO-WEIGHTS:/weights -v PATH-TO-DATABASES:/databases -v ${PWD}/results:/results rosettafold:latest /bin/bash
|
||||
```
|
||||
|
||||
5. Start inference/predictions.
|
||||
|
||||
To run inference you have to prepare a FASTA file and pass a path to it or pass a sequence directly.
|
||||
```
|
||||
python run_inference_pipeline.py [Sequence]
|
||||
```
|
||||
There is an example FASTA file at `example/input.fa` for you to try. Running the inference pipeline consists of four steps:
|
||||
1. Preparing the Multiple Sequence Alignments (MSAs)
|
||||
2. Preparing the secondary structures
|
||||
3. Preparing the templates
|
||||
4. Iteratively refining the prediction
|
||||
|
||||
The first three steps can take between a couple of minutes and an hour, depending on the sequence.
|
||||
The output will be stored at the `/results` directory as an `output.e2e.pdb` file
|
||||
|
||||
6. Start Jupyter Notebook to run inference interactively.
|
||||
|
||||
To launch the application, copy the Notebook to the root folder.
|
||||
```
|
||||
cp notebooks/run_inference.ipynb .
|
||||
|
||||
```
|
||||
To start Jupyter Notebook, run:
|
||||
```
|
||||
jupyter notebook run_inference.ipynb
|
||||
```
|
||||
|
||||
For more information about Jupyter Notebook, refer to the Jupyter Notebook documentation.
|
||||
|
||||
|
||||
## Release notes
|
||||
|
||||
### Changelog
|
||||
|
||||
October 2021
|
||||
- Initial release
|
||||
|
||||
### Known issues
|
||||
|
||||
There are no known issues with this model.
|
||||
|
||||
|
||||
|
107
DGLPyTorch/DrugDiscovery/RoseTTAFold/RoseTTAFold-linux-cu101.yml
Normal file
107
DGLPyTorch/DrugDiscovery/RoseTTAFold/RoseTTAFold-linux-cu101.yml
Normal file
|
@ -0,0 +1,107 @@
|
|||
name: RoseTTAFold
|
||||
channels:
|
||||
- rusty1s
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
- bioconda
|
||||
- biocore
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- biopython=1.78=py38h497a2fe_2
|
||||
- blas=1.0=mkl
|
||||
- hhsuite
|
||||
- blast-legacy=2.2.26=2
|
||||
- brotlipy=0.7.0=py38h497a2fe_1001
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2021.7.5=h06a4308_1
|
||||
- certifi=2021.5.30=py38h06a4308_0
|
||||
- cffi=1.14.5=py38ha65f79e_0
|
||||
- chardet=4.0.0=py38h578d9bd_1
|
||||
- cryptography=3.4.7=py38ha5dfef3_0
|
||||
- cudatoolkit=10.2.89=h8f6ccaa_8
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- freetype=2.10.4=h5ab3b9f_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- googledrivedownloader=0.4=pyhd3deb0d_1
|
||||
- idna=2.10=pyh9f0ad1d_0
|
||||
- intel-openmp=2021.2.0=h06a4308_610
|
||||
- jinja2=3.0.1=pyhd8ed1ab_0
|
||||
- joblib=1.0.1=pyhd8ed1ab_0
|
||||
- jpeg=9b=h024ee3a_2
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgfortran-ng=7.5.0=h14aa051_19
|
||||
- libgfortran4=7.5.0=h14aa051_19
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.1=h27cfd23_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp-base=1.2.0=h27cfd23_0
|
||||
- lz4-c=1.9.3=h2531618_0
|
||||
- markupsafe=2.0.1=py38h497a2fe_0
|
||||
- mkl=2021.2.0=h06a4308_296
|
||||
- mkl-service=2.3.0=py38h27cfd23_1
|
||||
- mkl_fft=1.3.0=py38h42c9631_2
|
||||
- mkl_random=1.2.1=py38ha9443f7_2
|
||||
- ncurses=6.2=he6710b0_1
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- networkx=2.5=py_0
|
||||
- ninja=1.10.2=hff7bd54_1
|
||||
- numpy=1.20.2=py38h2d18471_0
|
||||
- numpy-base=1.20.2=py38hfae3a4d_0
|
||||
- olefile=0.46=py_0
|
||||
- openh264=2.1.0=hd408876_0
|
||||
- openssl=1.1.1k=h27cfd23_0
|
||||
- packaging=20.9=pyhd3eb1b0_0
|
||||
- pandas=1.2.5=py38h1abd341_0
|
||||
- pillow=8.2.0=py38he98fc37_0
|
||||
- pip=21.1.3=py38h06a4308_0
|
||||
- psipred=4.01=1
|
||||
- pycparser=2.20=pyh9f0ad1d_2
|
||||
- pyopenssl=20.0.1=pyhd8ed1ab_0
|
||||
- pyparsing=2.4.7=pyh9f0ad1d_0
|
||||
- pysocks=1.7.1=py38h578d9bd_3
|
||||
- python=3.8.10=h12debd9_8
|
||||
- python-dateutil=2.8.1=py_0
|
||||
- python-louvain=0.15=pyhd3deb0d_0
|
||||
- python_abi=3.8=2_cp38
|
||||
- pytorch=1.8.1=py3.8_cuda10.2_cudnn7.6.5_0
|
||||
- pytorch-cluster=1.5.9=py38_torch_1.8.0_cu102
|
||||
- pytorch-geometric=1.7.2=py38_torch_1.8.0_cu102
|
||||
- pytorch-scatter=2.0.7=py38_torch_1.8.0_cu102
|
||||
- pytorch-sparse=0.6.10=py38_torch_1.8.0_cu102
|
||||
- pytorch-spline-conv=1.2.1=py38_torch_1.8.0_cu102
|
||||
- pytz=2021.1=pyhd8ed1ab_0
|
||||
- readline=8.1=h27cfd23_0
|
||||
- requests=2.25.1=pyhd3deb0d_0
|
||||
- scikit-learn=0.24.2=py38ha9443f7_0
|
||||
- setuptools=52.0.0=py38h06a4308_0
|
||||
- six=1.16.0=pyhd3eb1b0_0
|
||||
- sqlite=3.36.0=hc218d9a_0
|
||||
- threadpoolctl=2.1.0=pyh5ca1d4c_0
|
||||
- tk=8.6.10=hbc83047_0
|
||||
- torchvision=0.9.1=py38_cu102
|
||||
- tqdm=4.61.1=pyhd8ed1ab_0
|
||||
- typing_extensions=3.10.0.0=pyh06a4308_0
|
||||
- urllib3=1.26.6=pyhd8ed1ab_0
|
||||
- wheel=0.36.2=pyhd3eb1b0_0
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- zlib=1.2.11=h7b6447c_3
|
||||
- zstd=1.4.9=haebb681_0
|
||||
- pip:
|
||||
- decorator==4.4.2
|
||||
- dgl-cu102==0.6.1
|
||||
- lie-learn==0.0.1.post1
|
||||
- scipy==1.7.0
|
108
DGLPyTorch/DrugDiscovery/RoseTTAFold/RoseTTAFold-linux.yml
Normal file
108
DGLPyTorch/DrugDiscovery/RoseTTAFold/RoseTTAFold-linux.yml
Normal file
|
@ -0,0 +1,108 @@
|
|||
name: RoseTTAFold
|
||||
channels:
|
||||
- rusty1s
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
- bioconda
|
||||
- biocore
|
||||
dependencies:
|
||||
- biopython=1.78
|
||||
- biocore::blast-legacy=2.2.26
|
||||
- hhsuite
|
||||
- psipred=4.01
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- blas=1.0=mkl
|
||||
- brotlipy=0.7.0=py38h497a2fe_1001
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2021.5.25=h06a4308_1
|
||||
- certifi=2021.5.30=py38h06a4308_0
|
||||
- cffi=1.14.5=py38ha65f79e_0
|
||||
- chardet=4.0.0=py38h578d9bd_1
|
||||
- cryptography=3.4.7=py38ha5dfef3_0
|
||||
- cudatoolkit=11.1.74=h6bb024c_0
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- freetype=2.10.4=h5ab3b9f_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- googledrivedownloader=0.4=pyhd3deb0d_1
|
||||
- idna=2.10=pyh9f0ad1d_0
|
||||
- intel-openmp=2021.2.0=h06a4308_610
|
||||
- jinja2=3.0.1=pyhd8ed1ab_0
|
||||
- joblib=1.0.1=pyhd8ed1ab_0
|
||||
- jpeg=9b=h024ee3a_2
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgfortran-ng=7.5.0=h14aa051_19
|
||||
- libgfortran4=7.5.0=h14aa051_19
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.1=h27cfd23_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp-base=1.2.0=h27cfd23_0
|
||||
- lz4-c=1.9.3=h2531618_0
|
||||
- markupsafe=2.0.1=py38h497a2fe_0
|
||||
- mkl=2021.2.0=h06a4308_296
|
||||
- mkl-service=2.3.0=py38h27cfd23_1
|
||||
- mkl_fft=1.3.0=py38h42c9631_2
|
||||
- mkl_random=1.2.1=py38ha9443f7_2
|
||||
- ncurses=6.2=he6710b0_1
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- networkx=2.5=py_0
|
||||
- ninja=1.10.2=hff7bd54_1
|
||||
- numpy=1.20.2=py38h2d18471_0
|
||||
- numpy-base=1.20.2=py38hfae3a4d_0
|
||||
- olefile=0.46=py_0
|
||||
- openh264=2.1.0=hd408876_0
|
||||
- openssl=1.1.1k=h27cfd23_0
|
||||
- packaging=20.9=pyhd3eb1b0_0
|
||||
- pandas=1.2.5=py38h1abd341_0
|
||||
- pillow=8.2.0=py38he98fc37_0
|
||||
- pip=21.1.3=py38h06a4308_0
|
||||
- pycparser=2.20=pyh9f0ad1d_2
|
||||
- pyopenssl=20.0.1=pyhd8ed1ab_0
|
||||
- pyparsing=2.4.7=pyh9f0ad1d_0
|
||||
- pysocks=1.7.1=py38h578d9bd_3
|
||||
- python=3.8.10=h12debd9_8
|
||||
- python-dateutil=2.8.1=py_0
|
||||
- python-louvain=0.15=pyhd3deb0d_0
|
||||
- python_abi=3.8=2_cp38
|
||||
- pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
|
||||
- pytorch-cluster=1.5.9=py38_torch_1.9.0_cu111
|
||||
- pytorch-geometric=1.7.2=py38_torch_1.9.0_cu111
|
||||
- pytorch-scatter=2.0.7=py38_torch_1.9.0_cu111
|
||||
- pytorch-sparse=0.6.10=py38_torch_1.9.0_cu111
|
||||
- pytorch-spline-conv=1.2.1=py38_torch_1.9.0_cu111
|
||||
- pytz=2021.1=pyhd8ed1ab_0
|
||||
- readline=8.1=h27cfd23_0
|
||||
- requests=2.25.1=pyhd3deb0d_0
|
||||
- scikit-learn=0.24.2=py38ha9443f7_0
|
||||
- setuptools=52.0.0=py38h06a4308_0
|
||||
- six=1.16.0=pyhd3eb1b0_0
|
||||
- sqlite=3.36.0=hc218d9a_0
|
||||
- threadpoolctl=2.1.0=pyh5ca1d4c_0
|
||||
- tk=8.6.10=hbc83047_0
|
||||
- torchaudio=0.9.0=py38
|
||||
- torchvision=0.10.0=py38_cu111
|
||||
- tqdm=4.61.1=pyhd8ed1ab_0
|
||||
- typing_extensions=3.10.0.0=pyh06a4308_0
|
||||
- urllib3=1.26.6=pyhd8ed1ab_0
|
||||
- wheel=0.36.2=pyhd3eb1b0_0
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- zlib=1.2.11=h7b6447c_3
|
||||
- zstd=1.4.9=haebb681_0
|
||||
- pip:
|
||||
- decorator==4.4.2
|
||||
- dgl-cu110==0.6.1
|
||||
- scipy==1.7.0
|
||||
- lie-learn==0.0.1.post1
|
2
DGLPyTorch/DrugDiscovery/RoseTTAFold/example/input.fa
Normal file
2
DGLPyTorch/DrugDiscovery/RoseTTAFold/example/input.fa
Normal file
|
@ -0,0 +1,2 @@
|
|||
>T1078 Tsp1, Trichoderma virens, 138 residues|
|
||||
MAAPTPADKSMMAAVPEWTITNLKRVCNAGNTSCTWTFGVDTHLATATSCTYVVKANANASQASGGPVTCGPYTITSSWSGQFGPNNGFTTFAVTDFSKKLIVWPAYTDVQVQAGKVVSPNQSYAPANLPLEHHHHHH
|
9
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding-linux.yml
Normal file
9
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding-linux.yml
Normal file
|
@ -0,0 +1,9 @@
|
|||
name: folding
|
||||
channels:
|
||||
- defaults
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- tensorflow-gpu=1.14
|
||||
- pandas
|
||||
- scikit-learn=0.24
|
||||
- parallel
|
43
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/README
Normal file
43
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/README
Normal file
|
@ -0,0 +1,43 @@
|
|||
###############################################################################
|
||||
RosettaTR.py: Relax on dualspace (2 rounds of relax w/ different set of
|
||||
restraints)
|
||||
###############################################################################
|
||||
|
||||
|
||||
usage: RosettaTR.py [-h] [-r NRESTARTS] [-pd PCUT] [-m {0,1,2}] [-bb BB]
|
||||
[-sg SG] [-n STEPS] [--save_chk] [--orient] [--no-orient]
|
||||
[--fastrelax] [--no-fastrelax] [--roll] [--no-roll]
|
||||
NPZ FASTA OUT
|
||||
|
||||
positional arguments:
|
||||
NPZ input distograms and anglegrams (NN predictions)
|
||||
FASTA input sequence
|
||||
OUT output model (in PDB format)
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-r NRESTARTS number of noisy restrarts (default: 3)
|
||||
-pd PCUT min probability of distance restraints (default: 0.05)
|
||||
-m {0,1,2} 0: sh+m+l, 1: (sh+m)+l, 2: (sh+m+l) (default: 2)
|
||||
-bb BB predicted backbone torsions (default: )
|
||||
-sg SG window size and order for a Savitzky-Golay filter (comma-
|
||||
separated) (default: )
|
||||
-n STEPS number of minimization steps (default: 1000)
|
||||
--save_chk save checkpoint files to restart (default: False)
|
||||
--orient use orientations (default: True)
|
||||
--no-orient
|
||||
--fastrelax perform FastRelax (default: True)
|
||||
--no-fastrelax
|
||||
--roll circularly shift 6d coordinate arrays by 1 (default: False)
|
||||
--no-roll
|
||||
|
||||
|
||||
# USAGE
|
||||
conda activate folding
|
||||
|
||||
# try: -m 0,1,2
|
||||
# -pd 0.05, 0.15, 0.25, 0.35, 0.45
|
||||
# repeat ~3-5 times for every combination of -m and -pd
|
||||
# !!! use '--roll' option if no-contact bin is the last one !!!
|
||||
python ./RosettaTR.py -m 0 -pd 0.15 fake.npz fake.fa model.pdb
|
||||
|
340
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/RosettaTR.py
Normal file
340
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/RosettaTR.py
Normal file
|
@ -0,0 +1,340 @@
|
|||
import sys,os,json
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
from arguments import *
|
||||
from utils import *
|
||||
from pyrosetta import *
|
||||
from pyrosetta.rosetta.protocols.minimization_packing import MinMover
|
||||
|
||||
vdw_weight = {0: 3.0, 1: 5.0, 2: 10.0}
|
||||
rsr_dist_weight = {0: 3.0, 1: 2.0, 3: 1.0}
|
||||
rsr_orient_weight = {0: 1.0, 1: 1.0, 3: 0.5}
|
||||
|
||||
def main():
|
||||
|
||||
########################################################
|
||||
# process inputs
|
||||
########################################################
|
||||
|
||||
# read params
|
||||
scriptdir = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(scriptdir + '/data/params.json') as jsonfile:
|
||||
params = json.load(jsonfile)
|
||||
|
||||
# get command line arguments
|
||||
args = get_args(params)
|
||||
print(args)
|
||||
if os.path.exists(args.OUT):
|
||||
return
|
||||
|
||||
# init PyRosetta
|
||||
init_cmd = list()
|
||||
init_cmd.append("-multithreading:interaction_graph_threads 1 -multithreading:total_threads 1")
|
||||
init_cmd.append("-hb_cen_soft")
|
||||
init_cmd.append("-detect_disulf -detect_disulf_tolerance 2.0") # detect disulfide bonds based on Cb-Cb distance (CEN mode) or SG-SG distance (FA mode)
|
||||
init_cmd.append("-relax:dualspace true -relax::minimize_bond_angles -default_max_cycles 200")
|
||||
init_cmd.append("-mute all")
|
||||
init(" ".join(init_cmd))
|
||||
|
||||
# read and process restraints & sequence
|
||||
seq = read_fasta(args.FASTA)
|
||||
L = len(seq)
|
||||
params['seq'] = seq
|
||||
rst = gen_rst(params)
|
||||
|
||||
########################################################
|
||||
# Scoring functions and movers
|
||||
########################################################
|
||||
sf = ScoreFunction()
|
||||
sf.add_weights_from_file(scriptdir + '/data/scorefxn.wts')
|
||||
|
||||
sf1 = ScoreFunction()
|
||||
sf1.add_weights_from_file(scriptdir + '/data/scorefxn1.wts')
|
||||
|
||||
sf_vdw = ScoreFunction()
|
||||
sf_vdw.add_weights_from_file(scriptdir + '/data/scorefxn_vdw.wts')
|
||||
|
||||
sf_cart = ScoreFunction()
|
||||
sf_cart.add_weights_from_file(scriptdir + '/data/scorefxn_cart.wts')
|
||||
|
||||
mmap = MoveMap()
|
||||
mmap.set_bb(True)
|
||||
mmap.set_chi(False)
|
||||
mmap.set_jump(True)
|
||||
|
||||
min_mover1 = MinMover(mmap, sf1, 'lbfgs_armijo_nonmonotone', 0.001, True)
|
||||
min_mover1.max_iter(1000)
|
||||
|
||||
min_mover_vdw = MinMover(mmap, sf_vdw, 'lbfgs_armijo_nonmonotone', 0.001, True)
|
||||
min_mover_vdw.max_iter(500)
|
||||
|
||||
min_mover_cart = MinMover(mmap, sf_cart, 'lbfgs_armijo_nonmonotone', 0.000001, True)
|
||||
min_mover_cart.max_iter(300)
|
||||
min_mover_cart.cartesian(True)
|
||||
|
||||
|
||||
if not os.path.exists("%s_before_relax.pdb"%('.'.join(args.OUT.split('.')[:-1]))):
|
||||
########################################################
|
||||
# initialize pose
|
||||
########################################################
|
||||
pose0 = pose_from_sequence(seq, 'centroid')
|
||||
|
||||
# mutate GLY to ALA
|
||||
for i,a in enumerate(seq):
|
||||
if a == 'G':
|
||||
mutator = rosetta.protocols.simple_moves.MutateResidue(i+1,'ALA')
|
||||
mutator.apply(pose0)
|
||||
print('mutation: G%dA'%(i+1))
|
||||
|
||||
if (args.bb == ''):
|
||||
print('setting random (phi,psi,omega)...')
|
||||
set_random_dihedral(pose0)
|
||||
else:
|
||||
print('setting predicted (phi,psi,omega)...')
|
||||
bb = np.load(args.bb)
|
||||
set_predicted_dihedral(pose0,bb['phi'],bb['psi'],bb['omega'])
|
||||
|
||||
remove_clash(sf_vdw, min_mover_vdw, pose0)
|
||||
|
||||
Emin = 99999.9
|
||||
|
||||
########################################################
|
||||
# minimization
|
||||
########################################################
|
||||
|
||||
for run in range(params['NRUNS']):
|
||||
# define repeat_mover here!! (update vdw weights: weak (1.0) -> strong (10.0)
|
||||
sf.set_weight(rosetta.core.scoring.vdw, vdw_weight.setdefault(run, 10.0))
|
||||
sf.set_weight(rosetta.core.scoring.atom_pair_constraint, rsr_dist_weight.setdefault(run, 1.0))
|
||||
sf.set_weight(rosetta.core.scoring.dihedral_constraint, rsr_orient_weight.setdefault(run, 0.5))
|
||||
sf.set_weight(rosetta.core.scoring.angle_constraint, rsr_orient_weight.setdefault(run, 0.5))
|
||||
|
||||
min_mover = MinMover(mmap, sf, 'lbfgs_armijo_nonmonotone', 0.001, True)
|
||||
min_mover.max_iter(1000)
|
||||
|
||||
repeat_mover = RepeatMover(min_mover, 3)
|
||||
|
||||
#
|
||||
pose = Pose()
|
||||
pose.assign(pose0)
|
||||
pose.remove_constraints()
|
||||
|
||||
if run > 0:
|
||||
|
||||
# diversify backbone
|
||||
dphi = np.random.uniform(-10,10,L)
|
||||
dpsi = np.random.uniform(-10,10,L)
|
||||
for i in range(1,L+1):
|
||||
pose.set_phi(i,pose.phi(i)+dphi[i-1])
|
||||
pose.set_psi(i,pose.psi(i)+dpsi[i-1])
|
||||
|
||||
# remove clashes
|
||||
remove_clash(sf_vdw, min_mover_vdw, pose)
|
||||
|
||||
# Save checkpoint
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_init.pdb"%('.'.join(args.OUT.split('.')[:-1]), run))
|
||||
|
||||
if args.mode == 0:
|
||||
|
||||
# short
|
||||
print('short')
|
||||
add_rst(pose, rst, 3, 12, params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 0))
|
||||
|
||||
# medium
|
||||
print('medium')
|
||||
add_rst(pose, rst, 12, 24, params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 1))
|
||||
|
||||
# long
|
||||
print('long')
|
||||
add_rst(pose, rst, 24, len(seq), params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 2))
|
||||
|
||||
elif args.mode == 1:
|
||||
|
||||
# short + medium
|
||||
print('short + medium')
|
||||
add_rst(pose, rst, 3, 24, params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 0))
|
||||
|
||||
# long
|
||||
print('long')
|
||||
add_rst(pose, rst, 24, len(seq), params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 1))
|
||||
|
||||
elif args.mode == 2:
|
||||
|
||||
# short + medium + long
|
||||
print('short + medium + long')
|
||||
add_rst(pose, rst, 3, len(seq), params)
|
||||
repeat_mover.apply(pose)
|
||||
remove_clash(sf_vdw, min_mover1, pose)
|
||||
min_mover_cart.apply(pose)
|
||||
if args.save_chk:
|
||||
pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%('.'.join(args.OUT.split('.')[:-1]), run, args.mode, 0))
|
||||
|
||||
# check whether energy has decreased
|
||||
pose.conformation().detect_disulfides() # detect disulfide bonds
|
||||
E = sf_cart(pose)
|
||||
if E < Emin:
|
||||
print("Energy(iter=%d): %.1f --> %.1f (accept)"%(run, Emin, E))
|
||||
Emin = E
|
||||
pose0.assign(pose)
|
||||
else:
|
||||
print("Energy(iter=%d): %.1f --> %.1f (reject)"%(run, Emin, E))
|
||||
|
||||
# mutate ALA back to GLY
|
||||
for i,a in enumerate(seq):
|
||||
if a == 'G':
|
||||
mutator = rosetta.protocols.simple_moves.MutateResidue(i+1,'GLY')
|
||||
mutator.apply(pose0)
|
||||
print('mutation: A%dG'%(i+1))
|
||||
|
||||
########################################################
|
||||
# fix backbone geometry
|
||||
########################################################
|
||||
pose0.remove_constraints()
|
||||
|
||||
# apply more strict criteria to detect disulfide bond
|
||||
# Set options for disulfide tolerance -> 1.0A
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
rosetta.basic.options.set_real_option('in:detect_disulf_tolerance', 1.0)
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
pose0.conformation().detect_disulfides()
|
||||
|
||||
# Converto to all atom representation
|
||||
switch = SwitchResidueTypeSetMover("fa_standard")
|
||||
switch.apply(pose0)
|
||||
|
||||
# idealize problematic local regions if exists
|
||||
idealize = rosetta.protocols.idealize.IdealizeMover()
|
||||
poslist = rosetta.utility.vector1_unsigned_long()
|
||||
|
||||
scorefxn=create_score_function('empty')
|
||||
scorefxn.set_weight(rosetta.core.scoring.cart_bonded, 1.0)
|
||||
scorefxn.score(pose0)
|
||||
|
||||
emap = pose0.energies()
|
||||
print("idealize...")
|
||||
for res in range(1,L+1):
|
||||
cart = emap.residue_total_energy(res)
|
||||
if cart > 50:
|
||||
poslist.append(res)
|
||||
print( "idealize %d %8.3f"%(res,cart) )
|
||||
|
||||
if len(poslist) > 0:
|
||||
idealize.set_pos_list(poslist)
|
||||
try:
|
||||
idealize.apply(pose0)
|
||||
|
||||
except:
|
||||
print('!!! idealization failed !!!')
|
||||
|
||||
# Save checkpoint
|
||||
if args.save_chk:
|
||||
pose0.dump_pdb("%s_before_relax.pdb"%'.'.join(args.OUT.split('.')[:-1]))
|
||||
|
||||
else: # checkpoint exists
|
||||
pose0 = pose_from_file("%s_before_relax.pdb"%('.'.join(args.OUT.split('.')[:-1])))
|
||||
#
|
||||
print ("to centroid")
|
||||
switch_cen = SwitchResidueTypeSetMover("centroid")
|
||||
switch_cen.apply(pose0)
|
||||
#
|
||||
print ("detect disulfide bonds")
|
||||
# Set options for disulfide tolerance -> 1.0A
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
rosetta.basic.options.set_real_option('in:detect_disulf_tolerance', 1.0)
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
pose0.conformation().detect_disulfides()
|
||||
#
|
||||
print ("to all atom")
|
||||
switch = SwitchResidueTypeSetMover("fa_standard")
|
||||
switch.apply(pose0)
|
||||
|
||||
|
||||
########################################################
|
||||
# full-atom refinement
|
||||
########################################################
|
||||
|
||||
if args.fastrelax == True:
|
||||
mmap = MoveMap()
|
||||
mmap.set_bb(True)
|
||||
mmap.set_chi(True)
|
||||
mmap.set_jump(True)
|
||||
|
||||
# First round: Repeat 2 torsion space relax w/ strong disto/anglogram constraints
|
||||
sf_fa_round1 = create_score_function('ref2015_cart')
|
||||
sf_fa_round1.set_weight(rosetta.core.scoring.atom_pair_constraint, 3.0)
|
||||
sf_fa_round1.set_weight(rosetta.core.scoring.dihedral_constraint, 1.0)
|
||||
sf_fa_round1.set_weight(rosetta.core.scoring.angle_constraint, 1.0)
|
||||
sf_fa_round1.set_weight(rosetta.core.scoring.pro_close, 0.0)
|
||||
|
||||
relax_round1 = rosetta.protocols.relax.FastRelax(sf_fa_round1, "%s/data/relax_round1.txt"%scriptdir)
|
||||
relax_round1.set_movemap(mmap)
|
||||
|
||||
print('relax: First round... (focused on torsion space relaxation)')
|
||||
params['PCUT'] = 0.15
|
||||
pose0.remove_constraints()
|
||||
add_rst(pose0, rst, 3, len(seq), params, nogly=True, use_orient=True)
|
||||
relax_round1.apply(pose0)
|
||||
|
||||
# Set options for disulfide tolerance -> 0.5A
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
rosetta.basic.options.set_real_option('in:detect_disulf_tolerance', 0.5)
|
||||
print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
|
||||
|
||||
sf_fa = create_score_function('ref2015_cart')
|
||||
sf_fa.set_weight(rosetta.core.scoring.atom_pair_constraint, 0.1)
|
||||
sf_fa.set_weight(rosetta.core.scoring.dihedral_constraint, 0.0)
|
||||
sf_fa.set_weight(rosetta.core.scoring.angle_constraint, 0.0)
|
||||
|
||||
relax_round2 = rosetta.protocols.relax.FastRelax(sf_fa, "%s/data/relax_round2.txt"%scriptdir)
|
||||
relax_round2.set_movemap(mmap)
|
||||
relax_round2.cartesian(True)
|
||||
relax_round2.dualspace(True)
|
||||
|
||||
print('relax: Second round... (cartesian space)')
|
||||
params['PCUT'] = 0.30 # To reduce the number of pair restraints..
|
||||
pose0.remove_constraints()
|
||||
pose0.conformation().detect_disulfides() # detect disulfide bond again w/ stricter cutoffs
|
||||
# To reduce the number of constraints, only pair distances are considered w/ higher prob cutoffs
|
||||
add_rst(pose0, rst, 3, len(seq), params, nogly=True, use_orient=False, p12_cut=params['PCUT'])
|
||||
# Instead, apply CA coordinate constraints to prevent drifting away too much (focus on local refinement?)
|
||||
add_crd_rst(pose0, L, std=1.0, tol=2.0)
|
||||
relax_round2.apply(pose0)
|
||||
|
||||
# Re-evaluate score w/o any constraints
|
||||
scorefxn_min=create_score_function('ref2015_cart')
|
||||
scorefxn_min.score(pose0)
|
||||
|
||||
########################################################
|
||||
# save final model
|
||||
########################################################
|
||||
pose0.dump_pdb(args.OUT)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
38
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/arguments.py
Normal file
38
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/arguments.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import argparse
|
||||
|
||||
def get_args(params):
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("NPZ", type=str, help="input distograms and anglegrams (NN predictions)")
|
||||
parser.add_argument("FASTA", type=str, help="input sequence")
|
||||
parser.add_argument("OUT", type=str, help="output model (in PDB format)")
|
||||
|
||||
parser.add_argument('-r', type=int, dest='nrestarts', default=params['NRUNS'], help='number of noisy restrarts')
|
||||
parser.add_argument('-pd', type=float, dest='pcut', default=params['PCUT'], help='min probability of distance restraints')
|
||||
parser.add_argument('-m', type=int, dest='mode', default=2, choices=[0,1,2], help='0: sh+m+l, 1: (sh+m)+l, 2: (sh+m+l)')
|
||||
parser.add_argument('-bb', type=str, dest='bb', default='', help='predicted backbone torsions')
|
||||
parser.add_argument('-sg', type=str, dest='sg', default='', help='window size and order for a Savitzky-Golay filter (comma-separated)')
|
||||
parser.add_argument('-n', type=int, dest='steps', default=1000, help='number of minimization steps')
|
||||
parser.add_argument('--save_chk', dest='save_chk', default=False, action='store_true', help='save checkpoint files to restart')
|
||||
parser.add_argument('--orient', dest='use_orient', action='store_true', help='use orientations')
|
||||
parser.add_argument('--no-orient', dest='use_orient', action='store_false')
|
||||
parser.add_argument('--fastrelax', dest='fastrelax', action='store_true', help='perform FastRelax')
|
||||
parser.add_argument('--no-fastrelax', dest='fastrelax', action='store_false')
|
||||
parser.add_argument('--roll', dest='roll', action='store_true', help='circularly shift 6d coordinate arrays by 1')
|
||||
parser.add_argument('--no-roll', dest='roll', action='store_false')
|
||||
parser.set_defaults(use_orient=True)
|
||||
parser.set_defaults(fastrelax=True)
|
||||
parser.set_defaults(roll=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
params['PCUT'] = args.pcut
|
||||
params['USE_ORIENT'] = args.use_orient
|
||||
params['NRUNS'] = args.nrestarts
|
||||
params['ROLL'] = args.roll
|
||||
|
||||
params['NPZ'] = args.NPZ
|
||||
params['BB'] = args.bb
|
||||
params['SG'] = args.sg
|
||||
|
||||
return args
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"PCUT" : 0.05,
|
||||
"PCUT1" : 0.5,
|
||||
"EBASE" : -0.5,
|
||||
"EREP" : [10.0,3.0,0.5],
|
||||
"DREP" : [0.0,2.0,3.5],
|
||||
"PREP" : 0.1,
|
||||
"SIGD" : 10.0,
|
||||
"SIGM" : 1.0,
|
||||
"MEFF" : 0.0001,
|
||||
"DCUT" : 19.5,
|
||||
"ALPHA" : 1.57,
|
||||
"DSTEP" : 0.5,
|
||||
"ASTEP" : 15.0,
|
||||
"BBWGHT" : 10.0,
|
||||
"NRUNS" : 3
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
switch:torsion
|
||||
repeat 2
|
||||
ramp_repack_min 0.02 0.01 1.0 100
|
||||
ramp_repack_min 0.250 0.01 0.5 100
|
||||
ramp_repack_min 0.550 0.01 0.1 100
|
||||
ramp_repack_min 1 0.00001 0.1 100
|
||||
accept_to_best
|
||||
endrepeat
|
||||
|
||||
switch:cartesian
|
||||
repeat 1
|
||||
ramp_repack_min 0.02 0.01 1.0 50
|
||||
ramp_repack_min 0.250 0.01 0.5 50
|
||||
ramp_repack_min 0.550 0.01 0.1 100
|
||||
ramp_repack_min 1 0.00001 0.1 200
|
||||
accept_to_best
|
||||
endrepeat
|
|
@ -0,0 +1,8 @@
|
|||
switch:cartesian
|
||||
repeat 2
|
||||
ramp_repack_min 0.02 0.01 1.0 50
|
||||
ramp_repack_min 0.250 0.01 0.5 50
|
||||
ramp_repack_min 0.550 0.01 0.1 100
|
||||
ramp_repack_min 1 0.00001 0.1 200
|
||||
accept_to_best
|
||||
endrepeat
|
|
@ -0,0 +1,7 @@
|
|||
cen_hb 5.0
|
||||
rama 1.0
|
||||
omega 0.5
|
||||
vdw 1.0
|
||||
atom_pair_constraint 5.0
|
||||
dihedral_constraint 4.0
|
||||
angle_constraint 4.0
|
|
@ -0,0 +1,7 @@
|
|||
cen_hb 5.0
|
||||
rama 1.0
|
||||
omega 0.5
|
||||
vdw 5.0
|
||||
atom_pair_constraint 1.0
|
||||
dihedral_constraint 0.5
|
||||
angle_constraint 0.5
|
|
@ -0,0 +1,9 @@
|
|||
hbond_sr_bb 3.0
|
||||
hbond_lr_bb 3.0
|
||||
rama 1.0
|
||||
omega 0.5
|
||||
vdw 5.0
|
||||
cart_bonded 5.0
|
||||
atom_pair_constraint 1.0
|
||||
dihedral_constraint 0.5
|
||||
angle_constraint 0.5
|
|
@ -0,0 +1,2 @@
|
|||
rama 1.0
|
||||
vdw 1.0
|
373
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/utils.py
Normal file
373
DGLPyTorch/DrugDiscovery/RoseTTAFold/folding/utils.py
Normal file
|
@ -0,0 +1,373 @@
|
|||
import numpy as np
|
||||
import random
|
||||
import scipy
|
||||
from scipy.signal import *
|
||||
from pyrosetta import *
|
||||
|
||||
eps = 1e-9
|
||||
P_ADD_OMEGA = 0.5
|
||||
P_ADD_THETA = 0.5
|
||||
P_ADD_PHI = 0.6
|
||||
|
||||
def gen_rst(params):
|
||||
|
||||
npz = np.load(params['NPZ'])
|
||||
|
||||
dist,omega,theta,phi = npz['dist'],npz['omega'],npz['theta'],npz['phi']
|
||||
|
||||
if params['ROLL']==True:
|
||||
print("Apply circular shift...")
|
||||
dist = np.roll(dist,1,axis=-1)
|
||||
omega = np.roll(omega,1,axis=-1)
|
||||
theta = np.roll(theta,1,axis=-1)
|
||||
phi = np.roll(phi,1,axis=-1)
|
||||
|
||||
dist = dist.astype(np.float32) + eps
|
||||
omega = omega.astype(np.float32) + eps
|
||||
theta = theta.astype(np.float32) + eps
|
||||
phi = phi.astype(np.float32) + eps
|
||||
|
||||
# dictionary to store Rosetta restraints
|
||||
rst = {'dist' : [], 'omega' : [], 'theta' : [], 'phi' : []}
|
||||
|
||||
########################################################
|
||||
# assign parameters
|
||||
########################################################
|
||||
PCUT = 0.05 #params['PCUT']
|
||||
EBASE = params['EBASE']
|
||||
EREP = params['EREP']
|
||||
DREP = params['DREP']
|
||||
PREP = params['PREP']
|
||||
SIGD = params['SIGD']
|
||||
SIGM = params['SIGM']
|
||||
MEFF = params['MEFF']
|
||||
DCUT = params['DCUT']
|
||||
ALPHA = params['ALPHA']
|
||||
BBWGHT = params['BBWGHT']
|
||||
|
||||
DSTEP = params['DSTEP']
|
||||
ASTEP = np.deg2rad(params['ASTEP'])
|
||||
|
||||
seq = params['seq']
|
||||
|
||||
sg_flag = False
|
||||
if params['SG'] != '':
|
||||
sg_flag = True
|
||||
sg_w,sg_n = [int(v) for v in params['SG'].split(",")]
|
||||
print("Savitzky-Golay: %d,%d"%(sg_w,sg_n))
|
||||
|
||||
########################################################
|
||||
# dist: 0..20A
|
||||
########################################################
|
||||
nres = dist.shape[0]
|
||||
bins = np.array([4.25+DSTEP*i for i in range(32)])
|
||||
prob = np.sum(dist[:,:,5:], axis=-1) # prob of dist within 20A
|
||||
prob_12 = np.sum(dist[:,:,5:21], axis=-1) # prob of dist within 12A
|
||||
bkgr = np.array((bins/DCUT)**ALPHA)
|
||||
attr = -np.log((dist[:,:,5:]+MEFF)/(dist[:,:,-1][:,:,None]*bkgr[None,None,:]))+EBASE
|
||||
repul = np.maximum(attr[:,:,0],np.zeros((nres,nres)))[:,:,None]+np.array(EREP)[None,None,:]
|
||||
dist = np.concatenate([repul,attr], axis=-1)
|
||||
bins = np.concatenate([DREP,bins])
|
||||
x = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [x.append(v) for v in bins]
|
||||
#
|
||||
prob = np.triu(prob, k=1) # fill zeros to diagonal and lower (for speed-up)
|
||||
i,j = np.where(prob>PCUT)
|
||||
prob = prob[i,j]
|
||||
prob_12 = prob_12[i,j]
|
||||
#nbins = 35
|
||||
step = 0.5
|
||||
for a,b,p,p_12 in zip(i,j,prob,prob_12):
|
||||
y = pyrosetta.rosetta.utility.vector1_double()
|
||||
if sg_flag == True:
|
||||
_ = [y.append(v) for v in savgol_filter(dist[a,b],sg_w,sg_n)]
|
||||
else:
|
||||
_ = [y.append(v) for v in dist[a,b]]
|
||||
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, step, x,y)
|
||||
ida = rosetta.core.id.AtomID(5,a+1)
|
||||
idb = rosetta.core.id.AtomID(5,b+1)
|
||||
rst['dist'].append([a,b,p,p_12,rosetta.core.scoring.constraints.AtomPairConstraint(ida, idb, spline)])
|
||||
print("dist restraints: %d"%(len(rst['dist'])))
|
||||
|
||||
|
||||
########################################################
|
||||
# omega: -pi..pi
|
||||
########################################################
|
||||
nbins = omega.shape[2]-1
|
||||
ASTEP = 2.0*np.pi/nbins
|
||||
nbins += 4
|
||||
bins = np.linspace(-np.pi-1.5*ASTEP, np.pi+1.5*ASTEP, nbins)
|
||||
x = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [x.append(v) for v in bins]
|
||||
prob = np.sum(omega[:,:,1:], axis=-1)
|
||||
prob = np.triu(prob, k=1) # fill zeros to diagonal and lower (for speed-up)
|
||||
i,j = np.where(prob>PCUT+P_ADD_OMEGA)
|
||||
prob = prob[i,j]
|
||||
omega = -np.log((omega+MEFF)/(omega[:,:,-1]+MEFF)[:,:,None])
|
||||
#if sg_flag == True:
|
||||
# omega = savgol_filter(omega,sg_w,sg_n,axis=-1,mode='wrap')
|
||||
omega = np.concatenate([omega[:,:,-2:],omega[:,:,1:],omega[:,:,1:3]],axis=-1)
|
||||
for a,b,p in zip(i,j,prob):
|
||||
y = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [y.append(v) for v in omega[a,b]]
|
||||
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
|
||||
id1 = rosetta.core.id.AtomID(2,a+1) # CA-i
|
||||
id2 = rosetta.core.id.AtomID(5,a+1) # CB-i
|
||||
id3 = rosetta.core.id.AtomID(5,b+1) # CB-j
|
||||
id4 = rosetta.core.id.AtomID(2,b+1) # CA-j
|
||||
rst['omega'].append([a,b,p,rosetta.core.scoring.constraints.DihedralConstraint(id1,id2,id3,id4, spline)])
|
||||
print("omega restraints: %d"%(len(rst['omega'])))
|
||||
|
||||
|
||||
########################################################
|
||||
# theta: -pi..pi
|
||||
########################################################
|
||||
prob = np.sum(theta[:,:,1:], axis=-1)
|
||||
np.fill_diagonal(prob, 0.0)
|
||||
i,j = np.where(prob>PCUT+P_ADD_THETA)
|
||||
prob = prob[i,j]
|
||||
theta = -np.log((theta+MEFF)/(theta[:,:,-1]+MEFF)[:,:,None])
|
||||
#if sg_flag == True:
|
||||
# theta = savgol_filter(theta,sg_w,sg_n,axis=-1,mode='wrap')
|
||||
theta = np.concatenate([theta[:,:,-2:],theta[:,:,1:],theta[:,:,1:3]],axis=-1)
|
||||
for a,b,p in zip(i,j,prob):
|
||||
y = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [y.append(v) for v in theta[a,b]]
|
||||
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
|
||||
id1 = rosetta.core.id.AtomID(1,a+1) # N-i
|
||||
id2 = rosetta.core.id.AtomID(2,a+1) # CA-i
|
||||
id3 = rosetta.core.id.AtomID(5,a+1) # CB-i
|
||||
id4 = rosetta.core.id.AtomID(5,b+1) # CB-j
|
||||
rst['theta'].append([a,b,p,rosetta.core.scoring.constraints.DihedralConstraint(id1,id2,id3,id4, spline)])
|
||||
|
||||
print("theta restraints: %d"%(len(rst['theta'])))
|
||||
|
||||
|
||||
########################################################
|
||||
# phi: 0..pi
|
||||
########################################################
|
||||
nbins = phi.shape[2]-1+4
|
||||
bins = np.linspace(-1.5*ASTEP, np.pi+1.5*ASTEP, nbins)
|
||||
x = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [x.append(v) for v in bins]
|
||||
prob = np.sum(phi[:,:,1:], axis=-1)
|
||||
np.fill_diagonal(prob, 0.0)
|
||||
i,j = np.where(prob>PCUT+P_ADD_PHI)
|
||||
prob = prob[i,j]
|
||||
phi = -np.log((phi+MEFF)/(phi[:,:,-1]+MEFF)[:,:,None])
|
||||
#if sg_flag == True:
|
||||
# phi = savgol_filter(phi,sg_w,sg_n,axis=-1,mode='mirror')
|
||||
phi = np.concatenate([np.flip(phi[:,:,1:3],axis=-1),phi[:,:,1:],np.flip(phi[:,:,-2:],axis=-1)], axis=-1)
|
||||
for a,b,p in zip(i,j,prob):
|
||||
y = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [y.append(v) for v in phi[a,b]]
|
||||
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
|
||||
id1 = rosetta.core.id.AtomID(2,a+1) # CA-i
|
||||
id2 = rosetta.core.id.AtomID(5,a+1) # CB-i
|
||||
id3 = rosetta.core.id.AtomID(5,b+1) # CB-j
|
||||
rst['phi'].append([a,b,p,rosetta.core.scoring.constraints.AngleConstraint(id1,id2,id3, spline)])
|
||||
print("phi restraints: %d"%(len(rst['phi'])))
|
||||
|
||||
########################################################
|
||||
# backbone torsions
|
||||
########################################################
|
||||
if (params['BB'] != ''):
|
||||
bbnpz = np.load(params['BB'])
|
||||
bbphi,bbpsi = bbnpz['phi'],bbnpz['psi']
|
||||
rst['bbphi'] = []
|
||||
rst['bbpsi'] = []
|
||||
nbins = bbphi.shape[1]+4
|
||||
step = 2.*np.pi/bbphi.shape[1]
|
||||
bins = np.linspace(-1.5*step-np.pi, np.pi+1.5*step, nbins)
|
||||
x = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [x.append(v) for v in bins]
|
||||
|
||||
bbphi = -np.log(bbphi)
|
||||
bbphi = np.concatenate([bbphi[:,-2:],bbphi,bbphi[:,:2]],axis=-1).copy()
|
||||
|
||||
bbpsi = -np.log(bbpsi)
|
||||
bbpsi = np.concatenate([bbpsi[:,-2:],bbpsi,bbpsi[:,:2]],axis=-1).copy()
|
||||
|
||||
for i in range(1,nres):
|
||||
N1 = rosetta.core.id.AtomID(1,i)
|
||||
Ca1 = rosetta.core.id.AtomID(2,i)
|
||||
C1 = rosetta.core.id.AtomID(3,i)
|
||||
N2 = rosetta.core.id.AtomID(1,i+1)
|
||||
Ca2 = rosetta.core.id.AtomID(2,i+1)
|
||||
C2 = rosetta.core.id.AtomID(3,i+1)
|
||||
|
||||
# psi(i)
|
||||
ypsi = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [ypsi.append(v) for v in bbpsi[i-1]]
|
||||
spsi = rosetta.core.scoring.func.SplineFunc("", BBWGHT, 0.0, step, x,ypsi)
|
||||
rst['bbpsi'].append(rosetta.core.scoring.constraints.DihedralConstraint(N1,Ca1,C1,N2, spsi))
|
||||
|
||||
# phi(i+1)
|
||||
yphi = pyrosetta.rosetta.utility.vector1_double()
|
||||
_ = [yphi.append(v) for v in bbphi[i]]
|
||||
sphi = rosetta.core.scoring.func.SplineFunc("", BBWGHT, 0.0, step, x,yphi)
|
||||
rst['bbphi'].append(rosetta.core.scoring.constraints.DihedralConstraint(C1,N2,Ca2,C2, sphi))
|
||||
|
||||
print("bbbtor restraints: %d"%(len(rst['bbphi'])+len(rst['bbpsi'])))
|
||||
|
||||
return rst
|
||||
|
||||
def set_predicted_dihedral(pose, phi, psi, omega):
|
||||
|
||||
nbins = phi.shape[1]
|
||||
bins = np.linspace(-180.,180.,nbins+1)[:-1] + 180./nbins
|
||||
|
||||
nres = pose.total_residue()
|
||||
for i in range(nres):
|
||||
pose.set_phi(i+1,np.random.choice(bins,p=phi[i]))
|
||||
pose.set_psi(i+1,np.random.choice(bins,p=psi[i]))
|
||||
|
||||
if np.random.uniform() < omega[i,0]:
|
||||
pose.set_omega(i+1,0)
|
||||
else:
|
||||
pose.set_omega(i+1,180)
|
||||
|
||||
def set_random_dihedral(pose):
|
||||
nres = pose.total_residue()
|
||||
for i in range(1, nres+1):
|
||||
phi,psi=random_dihedral()
|
||||
pose.set_phi(i,phi)
|
||||
pose.set_psi(i,psi)
|
||||
pose.set_omega(i,180)
|
||||
|
||||
return(pose)
|
||||
|
||||
|
||||
#pick phi/psi randomly from:
|
||||
#-140 153 180 0.135 B
|
||||
# -72 145 180 0.155 B
|
||||
#-122 117 180 0.073 B
|
||||
# -82 -14 180 0.122 A
|
||||
# -61 -41 180 0.497 A
|
||||
# 57 39 180 0.018 L
|
||||
def random_dihedral():
|
||||
phi=0
|
||||
psi=0
|
||||
r=random.random()
|
||||
if(r<=0.135):
|
||||
phi=-140
|
||||
psi=153
|
||||
elif(r>0.135 and r<=0.29):
|
||||
phi=-72
|
||||
psi=145
|
||||
elif(r>0.29 and r<=0.363):
|
||||
phi=-122
|
||||
psi=117
|
||||
elif(r>0.363 and r<=0.485):
|
||||
phi=-82
|
||||
psi=-14
|
||||
elif(r>0.485 and r<=0.982):
|
||||
phi=-61
|
||||
psi=-41
|
||||
else:
|
||||
phi=57
|
||||
psi=39
|
||||
return(phi, psi)
|
||||
|
||||
|
||||
def read_fasta(file):
|
||||
fasta=""
|
||||
first = True
|
||||
with open(file, "r") as f:
|
||||
for line in f:
|
||||
if(line[0] == ">"):
|
||||
if first:
|
||||
first = False
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
line=line.rstrip()
|
||||
fasta = fasta + line;
|
||||
return fasta
|
||||
|
||||
|
||||
def remove_clash(scorefxn, mover, pose):
|
||||
for _ in range(0, 5):
|
||||
if float(scorefxn(pose)) < 10:
|
||||
break
|
||||
mover.apply(pose)
|
||||
|
||||
|
||||
def add_rst(pose, rst, sep1, sep2, params, nogly=False, use_orient=None, pcut=None, p12_cut=0.0):
|
||||
if use_orient == None:
|
||||
use_orient = params['USE_ORIENT']
|
||||
if pcut == None:
|
||||
pcut=params['PCUT']
|
||||
|
||||
seq = params['seq']
|
||||
|
||||
# collect restraints
|
||||
array = []
|
||||
|
||||
if nogly==True:
|
||||
dist_r = [r for a,b,p,p_12,r in rst['dist'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and seq[a]!='G' and seq[b]!='G' and p>=pcut and p_12>=p12_cut]
|
||||
if use_orient:
|
||||
omega_r = [r for a,b,p,r in rst['omega'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and seq[a]!='G' and seq[b]!='G' and p>=pcut+P_ADD_OMEGA] #0.5
|
||||
theta_r = [r for a,b,p,r in rst['theta'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and seq[a]!='G' and seq[b]!='G' and p>=pcut+P_ADD_THETA] #0.5
|
||||
phi_r = [r for a,b,p,r in rst['phi'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and seq[a]!='G' and seq[b]!='G' and p>=pcut+P_ADD_PHI] #0.6
|
||||
else:
|
||||
dist_r = [r for a,b,p,p_12,r in rst['dist'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and p>=pcut and p_12>=p12_cut]
|
||||
if use_orient:
|
||||
omega_r = [r for a,b,p,r in rst['omega'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and p>=pcut+P_ADD_OMEGA]
|
||||
theta_r = [r for a,b,p,r in rst['theta'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and p>=pcut+P_ADD_THETA]
|
||||
phi_r = [r for a,b,p,r in rst['phi'] if abs(a-b)>=sep1 and abs(a-b)<sep2 and p>=pcut+P_ADD_PHI] #0.6
|
||||
|
||||
#if params['BB'] != '':
|
||||
# array += [r for r in rst['bbphi']]
|
||||
# array += [r for r in rst['bbpsi']]
|
||||
array += dist_r
|
||||
if use_orient:
|
||||
array += omega_r
|
||||
array += theta_r
|
||||
array += phi_r
|
||||
|
||||
if len(array) < 1:
|
||||
return
|
||||
|
||||
print ("Number of applied pair restraints:", len(array))
|
||||
print (" - Distance restraints:", len(dist_r))
|
||||
if use_orient:
|
||||
print (" - Omega restraints:", len(omega_r))
|
||||
print (" - Theta restraints:", len(theta_r))
|
||||
print (" - Phi restraints: ", len(phi_r))
|
||||
|
||||
#random.shuffle(array)
|
||||
|
||||
cset = rosetta.core.scoring.constraints.ConstraintSet()
|
||||
[cset.add_constraint(a) for a in array]
|
||||
|
||||
# add to pose
|
||||
constraints = rosetta.protocols.constraint_movers.ConstraintSetMover()
|
||||
constraints.constraint_set(cset)
|
||||
constraints.add_constraints(True)
|
||||
constraints.apply(pose)
|
||||
|
||||
def add_crd_rst(pose, nres, std=1.0, tol=1.0):
|
||||
flat_har = rosetta.core.scoring.func.FlatHarmonicFunc(0.0, std, tol)
|
||||
rst = list()
|
||||
for i in range(1, nres+1):
|
||||
xyz = pose.residue(i).atom("CA").xyz() # xyz coord of CA atom
|
||||
ida = rosetta.core.id.AtomID(2,i) # CA idx for residue i
|
||||
rst.append(rosetta.core.scoring.constraints.CoordinateConstraint(ida, ida, xyz, flat_har))
|
||||
|
||||
if len(rst) < 1:
|
||||
return
|
||||
|
||||
print ("Number of applied coordinate restraints:", len(rst))
|
||||
#random.shuffle(rst)
|
||||
|
||||
cset = rosetta.core.scoring.constraints.ConstraintSet()
|
||||
[cset.add_constraint(a) for a in rst]
|
||||
|
||||
# add to pose
|
||||
constraints = rosetta.protocols.constraint_movers.ConstraintSetMover()
|
||||
constraints.constraint_set(cset)
|
||||
constraints.add_constraints(True)
|
||||
constraints.apply(pose)
|
||||
|
1
DGLPyTorch/DrugDiscovery/RoseTTAFold/funding.md
Normal file
1
DGLPyTorch/DrugDiscovery/RoseTTAFold/funding.md
Normal file
|
@ -0,0 +1 @@
|
|||
This work was supported by Microsoft (MB, DB, and generous gifts of Azure compute time and expertise), Eric and Wendy Schmidt by recommendation of the Schmidt Futures program (FD, HP), Open Philanthropy (DB, GRL), The Washington Research Foundation (MB, GRL, JW), National Science Foundation Cyberinfrastructure for Biological Research, Award # DBI 1937533 (IA), Wellcome Trust, grant number 209407/Z/17/Z (RJR), National Institute of Health, grant numbers P01GM063210 (PDA, RJR), DP5OD026389 (SO), Global Challenges Research Fund (GCRF) through Science & Technology Facilities Council (STFC), grant number ST/R002754/1: Synchrotron Techniques for African Research and Technology (START) (DJO), Austrian Science Fund (FWF) projects P29432 and DOC50 (doc.fund Molecular Metabolism) (TS, CB, TP).
|
Binary file not shown.
After Width: | Height: | Size: 115 KiB |
57
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/make_msa.sh
Executable file
57
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/make_msa.sh
Executable file
|
@ -0,0 +1,57 @@
|
|||
#!/bin/bash
|
||||
|
||||
# inputs
|
||||
in_fasta="$1"
|
||||
out_dir="$2"
|
||||
|
||||
# resources
|
||||
CPU="$3"
|
||||
MEM="$4"
|
||||
|
||||
# sequence database
|
||||
DB="$5/UniRef30_2020_06/UniRef30_2020_06"
|
||||
|
||||
# setup hhblits command
|
||||
HHBLITS="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB"
|
||||
echo $HHBLITS
|
||||
|
||||
mkdir -p $out_dir/hhblits
|
||||
tmp_dir="$out_dir/hhblits"
|
||||
out_prefix="$out_dir/t000_"
|
||||
|
||||
# perform iterative searches
|
||||
prev_a3m="$in_fasta"
|
||||
for e in 1e-30 1e-10 1e-6 1e-3
|
||||
do
|
||||
echo $e
|
||||
$HHBLITS -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0
|
||||
hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m
|
||||
hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m
|
||||
prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m"
|
||||
n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m`
|
||||
n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m`
|
||||
|
||||
if ((n75>2000))
|
||||
then
|
||||
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||
then
|
||||
cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m
|
||||
break
|
||||
fi
|
||||
elif ((n50>4000))
|
||||
then
|
||||
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||
then
|
||||
cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m
|
||||
break
|
||||
fi
|
||||
else
|
||||
continue
|
||||
fi
|
||||
|
||||
done
|
||||
|
||||
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||
then
|
||||
cp $tmp_dir/t000_.1e-3.id90cov50.a3m ${out_prefix}.msa0.a3m
|
||||
fi
|
29
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/make_ss.sh
Executable file
29
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/make_ss.sh
Executable file
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
|
||||
DATADIR="/workspace/psipred/data/"
|
||||
echo $DATADIR
|
||||
|
||||
i_a3m="$1"
|
||||
o_ss="$2"
|
||||
|
||||
ID=$(basename $i_a3m .a3m).tmp
|
||||
|
||||
/workspace/csblast-2.2.3/bin/csbuild -i $i_a3m -I a3m -D /workspace/csblast-2.2.3/data/K4000.crf -o $ID.chk -O chk
|
||||
|
||||
head -n 2 $i_a3m > $ID.fasta
|
||||
echo $ID.chk > $ID.pn
|
||||
echo $ID.fasta > $ID.sn
|
||||
|
||||
/workspace/blast-2.2.26/bin/makemat -P $ID
|
||||
/workspace/psipred/bin/psipred $ID.mtx $DATADIR/weights.dat $DATADIR/weights.dat2 $DATADIR/weights.dat3 > $ID.ss
|
||||
/workspace/psipred/bin/psipass2 $DATADIR/weights_p2.dat 1 1.0 1.0 $i_a3m.csb.hhblits.ss2 $ID.ss > $ID.horiz
|
||||
|
||||
(
|
||||
echo ">ss_pred"
|
||||
grep "^Pred" $ID.horiz | awk '{print $2}'
|
||||
echo ">ss_conf"
|
||||
grep "^Conf" $ID.horiz | awk '{print $2}'
|
||||
) | awk '{if(substr($1,1,1)==">") {print "\n"$1} else {printf "%s", $1}} END {print ""}' | sed "1d" > $o_ss
|
||||
|
||||
rm ${i_a3m}.csb.hhblits.ss2
|
||||
rm $ID.*
|
18
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/prepare_templates.sh
Executable file
18
DGLPyTorch/DrugDiscovery/RoseTTAFold/input_prep/prepare_templates.sh
Executable file
|
@ -0,0 +1,18 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
WDIR=$1
|
||||
CPU=$2
|
||||
MEM=$3
|
||||
|
||||
DB="$4/pdb100_2021Mar03/pdb100_2021Mar03"
|
||||
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB"
|
||||
cat $WDIR/t000_.ss2 $WDIR/t000_.msa0.a3m > $WDIR/t000_.msa0.ss2.a3m
|
||||
$HH -i $WDIR/t000_.msa0.ss2.a3m -o $WDIR/t000_.hhr -atab $WDIR/t000_.atab -v 0
|
||||
|
||||
|
27
DGLPyTorch/DrugDiscovery/RoseTTAFold/install_dependencies.sh
Executable file
27
DGLPyTorch/DrugDiscovery/RoseTTAFold/install_dependencies.sh
Executable file
|
@ -0,0 +1,27 @@
|
|||
#!/bin/bash
|
||||
|
||||
# install external program not supported by conda installation
|
||||
case "$(uname -s)" in
|
||||
Linux*) platform=linux;;
|
||||
Darwin*) platform=macosx;;
|
||||
*) echo "unsupported OS type. exiting"; exit 1
|
||||
esac
|
||||
echo "installing for ${platform}"
|
||||
|
||||
# download lddt
|
||||
echo "downloading lddt . . ."
|
||||
wget https://openstructure.org/static/lddt-${platform}.zip -O lddt.zip
|
||||
unzip -d lddt -j lddt.zip
|
||||
|
||||
# the cs-blast platform descriptoin includes the width of memory addresses
|
||||
# we expect a 64-bit operating system
|
||||
if [[ ${platform} == "linux" ]]; then
|
||||
platform=${platform}64
|
||||
fi
|
||||
|
||||
# download cs-blast
|
||||
echo "downloading cs-blast . . ."
|
||||
wget http://wwwuser.gwdg.de/~compbiol/data/csblast/releases/csblast-2.2.3_${platform}.tar.gz -O csblast-2.2.3.tar.gz
|
||||
mkdir -p csblast-2.2.3
|
||||
tar xf csblast-2.2.3.tar.gz -C csblast-2.2.3 --strip-components=1
|
||||
|
|
@ -0,0 +1,480 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from Transformer import *
|
||||
from Transformer import _get_clones
|
||||
import Transformer
|
||||
from resnet import ResidualNetwork
|
||||
from SE3_network import SE3Transformer
|
||||
from InitStrGenerator import InitStr_Network
|
||||
import dgl
|
||||
# Attention module based on AlphaFold2's idea written by Minkyung Baek
|
||||
# - Iterative MSA feature extraction
|
||||
# - 1) MSA2Pair: extract pairwise feature from MSA --> added to previous residue-pair features
|
||||
# architecture design inspired by CopulaNet paper
|
||||
# - 2) MSA2MSA: process MSA features using Transformer (or Performer) encoder. (Attention over L first followed by attention over N)
|
||||
# - 3) Pair2MSA: Update MSA features using pair feature
|
||||
# - 4) Pair2Pair: process pair features using Transformer (or Performer) encoder.
|
||||
|
||||
def make_graph(xyz, pair, idx, top_k=64, kmin=9):
|
||||
'''
|
||||
Input:
|
||||
- xyz: current backbone cooordinates (B, L, 3, 3)
|
||||
- pair: pair features from Trunk (B, L, L, E)
|
||||
- idx: residue index from ground truth pdb
|
||||
Output:
|
||||
- G: defined graph
|
||||
'''
|
||||
|
||||
B, L = xyz.shape[:2]
|
||||
device = xyz.device
|
||||
|
||||
# distance map from current CA coordinates
|
||||
D = torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]) + torch.eye(L, device=device).unsqueeze(0)*999.9 # (B, L, L)
|
||||
# seq sep
|
||||
sep = idx[:,None,:] - idx[:,:,None]
|
||||
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*999.9
|
||||
|
||||
# get top_k neighbors
|
||||
D_neigh, E_idx = torch.topk(D, min(top_k, L), largest=False) # shape of E_idx: (B, L, top_k)
|
||||
topk_matrix = torch.zeros((B, L, L), device=device)
|
||||
topk_matrix.scatter_(2, E_idx, 1.0)
|
||||
|
||||
# put an edge if any of the 3 conditions are met:
|
||||
# 1) |i-j| <= kmin (connect sequentially adjacent residues)
|
||||
# 2) top_k neighbors
|
||||
cond = torch.logical_or(topk_matrix > 0.0, sep < kmin)
|
||||
b,i,j = torch.where(cond)
|
||||
|
||||
src = b*L+i
|
||||
tgt = b*L+j
|
||||
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
|
||||
G.edata['d'] = (xyz[b,j,1,:] - xyz[b,i,1,:]).detach() # no gradient through basis function
|
||||
G.edata['w'] = pair[b,i,j]
|
||||
|
||||
return G
|
||||
|
||||
def get_bonded_neigh(idx):
|
||||
'''
|
||||
Input:
|
||||
- idx: residue indices of given sequence (B,L)
|
||||
Output:
|
||||
- neighbor: bonded neighbor information with sign (B, L, L, 1)
|
||||
'''
|
||||
neighbor = idx[:,None,:] - idx[:,:,None]
|
||||
neighbor = neighbor.float()
|
||||
sign = torch.sign(neighbor) # (B, L, L)
|
||||
neighbor = torch.abs(neighbor)
|
||||
neighbor[neighbor > 1] = 0.0
|
||||
neighbor = sign * neighbor
|
||||
return neighbor.unsqueeze(-1)
|
||||
|
||||
def rbf(D):
|
||||
# Distance radial basis function
|
||||
D_min, D_max, D_count = 0., 20., 36
|
||||
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
|
||||
D_mu = D_mu[None,:]
|
||||
D_sigma = (D_max - D_min) / D_count
|
||||
D_expand = torch.unsqueeze(D, -1)
|
||||
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
|
||||
return RBF
|
||||
|
||||
class CoevolExtractor(nn.Module):
|
||||
def __init__(self, n_feat_proj, n_feat_out, p_drop=0.1):
|
||||
super(CoevolExtractor, self).__init__()
|
||||
|
||||
self.norm_2d = LayerNorm(n_feat_proj*n_feat_proj)
|
||||
# project down to output dimension (pair feature dimension)
|
||||
self.proj_2 = nn.Linear(n_feat_proj**2, n_feat_out)
|
||||
def forward(self, x_down, x_down_w):
|
||||
B, N, L = x_down.shape[:3]
|
||||
|
||||
pair = torch.einsum('abij,ablm->ailjm', x_down, x_down_w) # outer-product & average pool
|
||||
pair = pair.reshape(B, L, L, -1)
|
||||
pair = self.norm_2d(pair)
|
||||
pair = self.proj_2(pair) # (B, L, L, n_feat_out) # project down to pair dimension
|
||||
return pair
|
||||
|
||||
class MSA2Pair(nn.Module):
|
||||
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32,
|
||||
n_resblock=1, p_drop=0.1, n_att_head=8):
|
||||
super(MSA2Pair, self).__init__()
|
||||
# project down embedding dimension (n_feat --> n_feat_proj)
|
||||
self.norm_1 = LayerNorm(n_feat)
|
||||
self.proj_1 = nn.Linear(n_feat, n_feat_proj)
|
||||
|
||||
self.encoder = SequenceWeight(n_feat_proj, 1, dropout=p_drop)
|
||||
self.coevol = CoevolExtractor(n_feat_proj, n_feat_out)
|
||||
|
||||
# ResNet to update pair features
|
||||
self.norm_down = LayerNorm(n_feat_proj)
|
||||
self.norm_orig = LayerNorm(n_feat_out)
|
||||
self.norm_new = LayerNorm(n_feat_out)
|
||||
self.update = ResidualNetwork(n_resblock, n_feat_out*2+n_feat_proj*4+n_att_head, n_feat_out, n_feat_out, p_drop=p_drop)
|
||||
|
||||
def forward(self, msa, pair_orig, att):
|
||||
# Input: MSA embeddings (B, N, L, K), original pair embeddings (B, L, L, C)
|
||||
# Output: updated pair info (B, L, L, C)
|
||||
B, N, L, _ = msa.shape
|
||||
# project down to reduce memory
|
||||
msa = self.norm_1(msa)
|
||||
x_down = self.proj_1(msa) # (B, N, L, n_feat_proj)
|
||||
|
||||
# get sequence weight
|
||||
x_down = self.norm_down(x_down)
|
||||
w_seq = self.encoder(x_down).reshape(B, L, 1, N).permute(0,3,1,2)
|
||||
feat_1d = w_seq*x_down
|
||||
|
||||
pair = self.coevol(x_down, feat_1d)
|
||||
|
||||
# average pooling over N of given MSA info
|
||||
feat_1d = feat_1d.sum(1)
|
||||
|
||||
# query sequence info
|
||||
query = x_down[:,0] # (B,L,K)
|
||||
feat_1d = torch.cat((feat_1d, query), dim=-1) # additional 1D features
|
||||
# tile 1D features
|
||||
left = feat_1d.unsqueeze(2).repeat(1, 1, L, 1)
|
||||
right = feat_1d.unsqueeze(1).repeat(1, L, 1, 1)
|
||||
# update original pair features through convolutions after concat
|
||||
pair_orig = self.norm_orig(pair_orig)
|
||||
pair = self.norm_new(pair)
|
||||
pair = torch.cat((pair_orig, pair, left, right, att), -1)
|
||||
pair = pair.permute(0,3,1,2).contiguous() # prep for convolution layer
|
||||
pair = self.update(pair)
|
||||
pair = pair.permute(0,2,3,1).contiguous() # (B, L, L, C)
|
||||
|
||||
return pair
|
||||
|
||||
class MSA2MSA(nn.Module):
|
||||
def __init__(self, n_layer=1, n_att_head=8, n_feat=256, r_ff=4, p_drop=0.1,
|
||||
performer_N_opts=None, performer_L_opts=None):
|
||||
super(MSA2MSA, self).__init__()
|
||||
# attention along L
|
||||
enc_layer_1 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
|
||||
heads=n_att_head, p_drop=p_drop,
|
||||
use_tied=True)
|
||||
#performer_opts=performer_L_opts)
|
||||
self.encoder_1 = Encoder(enc_layer_1, n_layer)
|
||||
|
||||
# attention along N
|
||||
enc_layer_2 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
|
||||
heads=n_att_head, p_drop=p_drop,
|
||||
performer_opts=performer_N_opts)
|
||||
self.encoder_2 = Encoder(enc_layer_2, n_layer)
|
||||
|
||||
def forward(self, x):
|
||||
# Input: MSA embeddings (B, N, L, K)
|
||||
# Output: updated MSA embeddings (B, N, L, K)
|
||||
B, N, L, _ = x.shape
|
||||
# attention along L
|
||||
x, att = self.encoder_1(x, return_att=True)
|
||||
# attention along N
|
||||
x = x.permute(0,2,1,3).contiguous()
|
||||
x = self.encoder_2(x)
|
||||
x = x.permute(0,2,1,3).contiguous()
|
||||
return x, att
|
||||
|
||||
class Pair2MSA(nn.Module):
|
||||
def __init__(self, n_layer=1, n_att_head=4, n_feat_in=128, n_feat_out=256, r_ff=4, p_drop=0.1):
|
||||
super(Pair2MSA, self).__init__()
|
||||
enc_layer = DirectEncoderLayer(heads=n_att_head, \
|
||||
d_in=n_feat_in, d_out=n_feat_out,\
|
||||
d_ff=n_feat_out*r_ff,\
|
||||
p_drop=p_drop)
|
||||
self.encoder = CrossEncoder(enc_layer, n_layer)
|
||||
|
||||
def forward(self, pair, msa):
|
||||
out = self.encoder(pair, msa) # (B, N, L, K)
|
||||
return out
|
||||
|
||||
class Pair2Pair(nn.Module):
|
||||
def __init__(self, n_layer=1, n_att_head=8, n_feat=128, r_ff=4, p_drop=0.1,
|
||||
performer_L_opts=None):
|
||||
super(Pair2Pair, self).__init__()
|
||||
enc_layer = AxialEncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
|
||||
heads=n_att_head, p_drop=p_drop,
|
||||
performer_opts=performer_L_opts)
|
||||
self.encoder = Encoder(enc_layer, n_layer)
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
class Str2Str(nn.Module):
|
||||
def __init__(self, d_msa=64, d_pair=128,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
|
||||
super(Str2Str, self).__init__()
|
||||
|
||||
# initial node & pair feature process
|
||||
self.norm_msa = LayerNorm(d_msa)
|
||||
self.norm_pair = LayerNorm(d_pair)
|
||||
self.encoder_seq = SequenceWeight(d_msa, 1, dropout=p_drop)
|
||||
|
||||
self.embed_x = nn.Linear(d_msa+21, SE3_param['l0_in_features'])
|
||||
self.embed_e = nn.Linear(d_pair, SE3_param['num_edge_features'])
|
||||
|
||||
self.norm_node = LayerNorm(SE3_param['l0_in_features'])
|
||||
self.norm_edge = LayerNorm(SE3_param['num_edge_features'])
|
||||
|
||||
self.se3 = SE3Transformer(**SE3_param)
|
||||
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
|
||||
# process msa & pair features
|
||||
B, N, L = msa.shape[:3]
|
||||
msa = self.norm_msa(msa)
|
||||
pair = self.norm_pair(pair)
|
||||
|
||||
w_seq = self.encoder_seq(msa).reshape(B, L, 1, N).permute(0,3,1,2)
|
||||
msa = w_seq*msa
|
||||
msa = msa.sum(dim=1)
|
||||
msa = torch.cat((msa, seq1hot), dim=-1)
|
||||
msa = self.norm_node(self.embed_x(msa))
|
||||
pair = self.norm_edge(self.embed_e(pair))
|
||||
|
||||
# define graph
|
||||
G = make_graph(xyz, pair, idx, top_k=top_k)
|
||||
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2) # l1 features = displacement vector to CA
|
||||
l1_feats = l1_feats.reshape(B*L, -1, 3)
|
||||
|
||||
# apply SE(3) Transformer & update coordinates
|
||||
shift = self.se3(G, msa.reshape(B*L, -1, 1), l1_feats)
|
||||
|
||||
state = shift['0'].reshape(B, L, -1) # (B, L, C)
|
||||
|
||||
offset = shift['1'].reshape(B, L, -1, 3) # (B, L, 3, 3)
|
||||
CA_new = xyz[:,:,1] + offset[:,:,1]
|
||||
N_new = CA_new + offset[:,:,0]
|
||||
C_new = CA_new + offset[:,:,2]
|
||||
xyz_new = torch.stack([N_new, CA_new, C_new], dim=2)
|
||||
|
||||
return xyz_new, state
|
||||
|
||||
class Str2MSA(nn.Module):
|
||||
def __init__(self, d_msa=64, d_state=32, inner_dim=32, r_ff=4,
|
||||
distbin=[8.0, 12.0, 16.0, 20.0], p_drop=0.1):
|
||||
super(Str2MSA, self).__init__()
|
||||
self.distbin = distbin
|
||||
n_att_head = len(distbin)
|
||||
|
||||
self.norm_state = LayerNorm(d_state)
|
||||
self.norm1 = LayerNorm(d_msa)
|
||||
self.attn = MaskedDirectMultiheadAttention(d_state, d_msa, n_att_head, d_k=inner_dim, dropout=p_drop)
|
||||
self.dropout1 = nn.Dropout(p_drop,inplace=True)
|
||||
|
||||
self.norm2 = LayerNorm(d_msa)
|
||||
self.ff = FeedForwardLayer(d_msa, d_msa*r_ff, p_drop=p_drop)
|
||||
self.dropout2 = nn.Dropout(p_drop,inplace=True)
|
||||
|
||||
def forward(self, msa, xyz, state):
|
||||
dist = torch.cdist(xyz[:,:,1], xyz[:,:,1]) # (B, L, L)
|
||||
|
||||
mask_s = list()
|
||||
for distbin in self.distbin:
|
||||
mask_s.append(1.0 - torch.sigmoid(dist-distbin))
|
||||
mask_s = torch.stack(mask_s, dim=1) # (B, h, L, L)
|
||||
|
||||
state = self.norm_state(state)
|
||||
msa2 = self.norm1(msa)
|
||||
msa2 = self.attn(state, state, msa2, mask_s)
|
||||
msa = msa + self.dropout1(msa2)
|
||||
|
||||
msa2 = self.norm2(msa)
|
||||
msa2 = self.ff(msa2)
|
||||
msa = msa + self.dropout2(msa2)
|
||||
|
||||
return msa
|
||||
|
||||
class IterBlock(nn.Module):
|
||||
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
|
||||
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None):
|
||||
super(IterBlock, self).__init__()
|
||||
|
||||
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
|
||||
r_ff=r_ff, p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
|
||||
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
|
||||
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
|
||||
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
|
||||
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
|
||||
|
||||
def forward(self, msa, pair):
|
||||
# input:
|
||||
# msa: initial MSA embeddings (N, L, d_msa)
|
||||
# pair: initial residue pair embeddings (L, L, d_pair)
|
||||
|
||||
# 1. process MSA features
|
||||
msa, att = self.msa2msa(msa)
|
||||
|
||||
# 2. update pair features using given MSA
|
||||
pair = self.msa2pair(msa, pair, att)
|
||||
|
||||
# 3. process pair features
|
||||
pair = self.pair2pair(pair)
|
||||
|
||||
# 4. update MSA features using updated pair features
|
||||
msa = self.pair2msa(pair, msa)
|
||||
|
||||
|
||||
return msa, pair
|
||||
|
||||
class IterBlock_w_Str(nn.Module):
|
||||
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
|
||||
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
|
||||
super(IterBlock_w_Str, self).__init__()
|
||||
|
||||
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
|
||||
r_ff=r_ff, p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
|
||||
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
|
||||
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
|
||||
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
|
||||
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
|
||||
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
|
||||
self.str2msa = Str2MSA(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
|
||||
r_ff=r_ff, p_drop=p_drop)
|
||||
|
||||
def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
|
||||
# input:
|
||||
# msa: initial MSA embeddings (N, L, d_msa)
|
||||
# pair: initial residue pair embeddings (L, L, d_pair)
|
||||
|
||||
# 1. process MSA features
|
||||
msa, att = self.msa2msa(msa)
|
||||
|
||||
# 2. update pair features using given MSA
|
||||
pair = self.msa2pair(msa, pair, att)
|
||||
|
||||
# 3. process pair features
|
||||
pair = self.pair2pair(pair)
|
||||
|
||||
# 4. update MSA features using updated pair features
|
||||
msa = self.pair2msa(pair, msa)
|
||||
|
||||
xyz, state = self.str2str(msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=top_k)
|
||||
msa = self.str2msa(msa, xyz, state)
|
||||
|
||||
return msa, pair, xyz
|
||||
|
||||
class FinalBlock(nn.Module):
|
||||
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
|
||||
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
|
||||
super(FinalBlock, self).__init__()
|
||||
|
||||
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
|
||||
r_ff=r_ff, p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
|
||||
n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa)
|
||||
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
|
||||
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
|
||||
performer_L_opts=performer_L_opts)
|
||||
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=4,
|
||||
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)
|
||||
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
|
||||
self.norm_state = LayerNorm(SE3_param['l0_out_features'])
|
||||
self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)
|
||||
|
||||
def forward(self, msa, pair, xyz, seq1hot, idx):
|
||||
# input:
|
||||
# msa: initial MSA embeddings (N, L, d_msa)
|
||||
# pair: initial residue pair embeddings (L, L, d_pair)
|
||||
|
||||
# 1. process MSA features
|
||||
msa, att = self.msa2msa(msa)
|
||||
|
||||
# 2. update pair features using given MSA
|
||||
pair = self.msa2pair(msa, pair, att)
|
||||
|
||||
# 3. process pair features
|
||||
pair = self.pair2pair(pair)
|
||||
|
||||
msa = self.pair2msa(pair, msa)
|
||||
|
||||
xyz, state = self.str2str(msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=32)
|
||||
|
||||
lddt = self.pred_lddt(self.norm_state(state))
|
||||
return msa, pair, xyz, lddt.squeeze(-1)
|
||||
|
||||
class IterativeFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_module=4, n_module_str=4, n_layer=4, d_msa=256, d_pair=128, d_hidden=64,
|
||||
n_head_msa=8, n_head_pair=8, r_ff=4,
|
||||
n_resblock=1, p_drop=0.1,
|
||||
performer_L_opts=None, performer_N_opts=None,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
|
||||
super(IterativeFeatureExtractor, self).__init__()
|
||||
self.n_module = n_module
|
||||
self.n_module_str = n_module_str
|
||||
#
|
||||
self.initial = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
|
||||
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
|
||||
performer_L_opts=performer_L_opts)
|
||||
|
||||
if self.n_module > 0:
|
||||
self.iter_block_1 = _get_clones(IterBlock(n_layer=n_layer,
|
||||
d_msa=d_msa, d_pair=d_pair,
|
||||
n_head_msa=n_head_msa,
|
||||
n_head_pair=n_head_pair,
|
||||
r_ff=r_ff,
|
||||
n_resblock=n_resblock,
|
||||
p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts
|
||||
), n_module)
|
||||
|
||||
self.init_str = InitStr_Network(node_dim_in=d_msa, node_dim_hidden=d_hidden,
|
||||
edge_dim_in=d_pair, edge_dim_hidden=d_hidden,
|
||||
nheads=4, nblocks=3, dropout=p_drop)
|
||||
|
||||
if self.n_module_str > 0:
|
||||
self.iter_block_2 = _get_clones(IterBlock_w_Str(n_layer=n_layer,
|
||||
d_msa=d_msa, d_pair=d_pair,
|
||||
n_head_msa=n_head_msa,
|
||||
n_head_pair=n_head_pair,
|
||||
r_ff=r_ff,
|
||||
n_resblock=n_resblock,
|
||||
p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts,
|
||||
SE3_param=SE3_param
|
||||
), n_module_str)
|
||||
|
||||
self.final = FinalBlock(n_layer=n_layer, d_msa=d_msa, d_pair=d_pair,
|
||||
n_head_msa=n_head_msa, n_head_pair=n_head_pair, r_ff=r_ff,
|
||||
n_resblock=n_resblock, p_drop=p_drop,
|
||||
performer_L_opts=performer_L_opts, performer_N_opts=performer_N_opts,
|
||||
SE3_param=SE3_param)
|
||||
|
||||
def forward(self, msa, pair, seq1hot, idx):
|
||||
# input:
|
||||
# msa: initial MSA embeddings (N, L, d_msa)
|
||||
# pair: initial residue pair embeddings (L, L, d_pair)
|
||||
|
||||
pair_s = list()
|
||||
pair = self.initial(pair)
|
||||
if self.n_module > 0:
|
||||
for i_m in range(self.n_module):
|
||||
# extract features from MSA & update original pair features
|
||||
msa, pair = self.iter_block_1[i_m](msa, pair)
|
||||
|
||||
xyz = self.init_str(seq1hot, idx, msa, pair)
|
||||
|
||||
top_ks = [128, 128, 64, 64]
|
||||
if self.n_module_str > 0:
|
||||
for i_m in range(self.n_module_str):
|
||||
msa, pair, xyz = self.iter_block_2[i_m](msa, pair, xyz, seq1hot, idx, top_k=top_ks[i_m])
|
||||
|
||||
msa, pair, xyz, lddt = self.final(msa, pair, xyz, seq1hot, idx)
|
||||
|
||||
return msa[:,0], pair, xyz, lddt
|
|
@ -0,0 +1,36 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from resnet import ResidualNetwork
|
||||
from Transformer import LayerNorm
|
||||
|
||||
# predict distance map from pair features
|
||||
# based on simple 2D ResNet
|
||||
|
||||
class DistanceNetwork(nn.Module):
|
||||
def __init__(self, n_feat, n_block=1, block_type='orig', p_drop=0.0):
|
||||
super(DistanceNetwork, self).__init__()
|
||||
self.norm = LayerNorm(n_feat)
|
||||
self.proj = nn.Linear(n_feat, n_feat)
|
||||
self.drop = nn.Dropout(p_drop)
|
||||
#
|
||||
self.resnet_dist = ResidualNetwork(n_block, n_feat, n_feat, 37, block_type=block_type, p_drop=p_drop)
|
||||
self.resnet_omega = ResidualNetwork(n_block, n_feat, n_feat, 37, block_type=block_type, p_drop=p_drop)
|
||||
self.resnet_theta = ResidualNetwork(n_block, n_feat, n_feat, 37, block_type=block_type, p_drop=p_drop)
|
||||
self.resnet_phi = ResidualNetwork(n_block, n_feat, n_feat, 19, block_type=block_type, p_drop=p_drop)
|
||||
|
||||
def forward(self, x):
|
||||
# input: pair info (B, L, L, C)
|
||||
x = self.norm(x)
|
||||
x = self.drop(self.proj(x))
|
||||
x = x.permute(0,3,1,2).contiguous()
|
||||
|
||||
# predict theta, phi (non-symmetric)
|
||||
logits_theta = self.resnet_theta(x)
|
||||
logits_phi = self.resnet_phi(x)
|
||||
|
||||
# predict dist, omega
|
||||
x = 0.5 * (x + x.permute(0,1,3,2))
|
||||
logits_dist = self.resnet_dist(x)
|
||||
logits_omega = self.resnet_omega(x)
|
||||
|
||||
return logits_dist, logits_omega, logits_theta, logits_phi
|
175
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Embeddings.py
Normal file
175
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Embeddings.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from Transformer import EncoderLayer, AxialEncoderLayer, Encoder, LayerNorm
|
||||
|
||||
# Initial embeddings for target sequence, msa, template info
|
||||
# positional encoding
|
||||
# option 1: using sin/cos --> using this for now
|
||||
# option 2: learn positional embedding
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, p_drop=0.1, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.drop = nn.Dropout(p_drop,inplace=True)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) *
|
||||
-(math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe) # (1, max_len, d_model)
|
||||
def forward(self, x, idx_s):
|
||||
pe = list()
|
||||
for idx in idx_s:
|
||||
pe.append(self.pe[:,idx,:])
|
||||
pe = torch.stack(pe)
|
||||
x = x + torch.autograd.Variable(pe, requires_grad=False)
|
||||
return self.drop(x)
|
||||
|
||||
class PositionalEncoding2D(nn.Module):
|
||||
def __init__(self, d_model, p_drop=0.1):
|
||||
super(PositionalEncoding2D, self).__init__()
|
||||
self.drop = nn.Dropout(p_drop,inplace=True)
|
||||
#
|
||||
d_model_half = d_model // 2
|
||||
div_term = torch.exp(torch.arange(0., d_model_half, 2) *
|
||||
-(math.log(10000.0) / d_model_half))
|
||||
self.register_buffer('div_term', div_term)
|
||||
|
||||
def forward(self, x, idx_s):
|
||||
B, L, _, K = x.shape
|
||||
K_half = K//2
|
||||
pe = torch.zeros_like(x)
|
||||
i_batch = -1
|
||||
for idx in idx_s:
|
||||
i_batch += 1
|
||||
sin_inp = idx.unsqueeze(1) * self.div_term
|
||||
emb = torch.cat((sin_inp.sin(), sin_inp.cos()), dim=-1) # (L, K//2)
|
||||
pe[i_batch,:,:,:K_half] = emb.unsqueeze(1)
|
||||
pe[i_batch,:,:,K_half:] = emb.unsqueeze(0)
|
||||
x = x + torch.autograd.Variable(pe, requires_grad=False)
|
||||
return self.drop(x)
|
||||
|
||||
class QueryEncoding(nn.Module):
|
||||
def __init__(self, d_model):
|
||||
super(QueryEncoding, self).__init__()
|
||||
self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, L, K = x.shape
|
||||
idx = torch.ones((B, N, L), device=x.device).long()
|
||||
idx[:,0,:] = 0 # first sequence is the query
|
||||
x = x + self.pe(idx)
|
||||
return x
|
||||
|
||||
class MSA_emb(nn.Module):
|
||||
def __init__(self, d_model=64, d_msa=21, p_drop=0.1, max_len=5000):
|
||||
super(MSA_emb, self).__init__()
|
||||
self.emb = nn.Embedding(d_msa, d_model)
|
||||
self.pos = PositionalEncoding(d_model, p_drop=p_drop, max_len=max_len)
|
||||
self.pos_q = QueryEncoding(d_model)
|
||||
def forward(self, msa, idx):
|
||||
B, N, L = msa.shape
|
||||
out = self.emb(msa) # (B, N, L, K//2)
|
||||
out = self.pos(out, idx) # add positional encoding
|
||||
return self.pos_q(out) # add query encoding
|
||||
|
||||
# pixel-wise attention based embedding (from trRosetta-tbm)
|
||||
class Templ_emb(nn.Module):
|
||||
def __init__(self, d_t1d=3, d_t2d=10, d_templ=64, n_att_head=4, r_ff=4,
|
||||
performer_opts=None, p_drop=0.1, max_len=5000):
|
||||
super(Templ_emb, self).__init__()
|
||||
self.proj = nn.Linear(d_t1d*2+d_t2d+1, d_templ)
|
||||
self.pos = PositionalEncoding2D(d_templ, p_drop=p_drop)
|
||||
# attention along L
|
||||
enc_layer_L = AxialEncoderLayer(d_templ, d_templ*r_ff, n_att_head, p_drop=p_drop,
|
||||
performer_opts=performer_opts)
|
||||
self.encoder_L = Encoder(enc_layer_L, 1)
|
||||
|
||||
self.norm = LayerNorm(d_templ)
|
||||
self.to_attn = nn.Linear(d_templ, 1)
|
||||
|
||||
def forward(self, t1d, t2d, idx):
|
||||
# Input
|
||||
# - t1d: 1D template info (B, T, L, 2)
|
||||
# - t2d: 2D template info (B, T, L, L, 10)
|
||||
B, T, L, _ = t1d.shape
|
||||
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
|
||||
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
|
||||
seqsep = torch.abs(idx[:,:,None]-idx[:,None,:]) + 1
|
||||
seqsep = torch.log(seqsep.float()).view(B,L,L,1).unsqueeze(1).expand(-1,T,-1,-1,-1)
|
||||
#
|
||||
feat = torch.cat((t2d, left, right, seqsep), -1)
|
||||
feat = self.proj(feat).reshape(B*T, L, L, -1)
|
||||
tmp = self.pos(feat, idx) # add positional embedding
|
||||
#
|
||||
# attention along L
|
||||
feat = torch.empty_like(tmp)
|
||||
for i_f in range(tmp.shape[0]):
|
||||
feat[i_f] = self.encoder_L(tmp[i_f].view(1,L,L,-1))
|
||||
del tmp
|
||||
feat = feat.reshape(B, T, L, L, -1)
|
||||
feat = feat.permute(0,2,3,1,4).contiguous().reshape(B, L*L, T, -1)
|
||||
|
||||
attn = self.to_attn(self.norm(feat))
|
||||
attn = F.softmax(attn, dim=-2) # (B, L*L, T, 1)
|
||||
feat = torch.matmul(attn.transpose(-2, -1), feat)
|
||||
return feat.reshape(B, L, L, -1)
|
||||
|
||||
class Pair_emb_w_templ(nn.Module):
|
||||
def __init__(self, d_model=128, d_seq=21, d_templ=64, p_drop=0.1):
|
||||
super(Pair_emb_w_templ, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.d_emb = d_model // 2
|
||||
self.emb = nn.Embedding(d_seq, self.d_emb)
|
||||
self.norm_templ = LayerNorm(d_templ)
|
||||
self.projection = nn.Linear(d_model + d_templ + 1, d_model)
|
||||
self.pos = PositionalEncoding2D(d_model, p_drop=p_drop)
|
||||
|
||||
def forward(self, seq, idx, templ):
|
||||
# input:
|
||||
# seq: target sequence (B, L, 20)
|
||||
B = seq.shape[0]
|
||||
L = seq.shape[1]
|
||||
#
|
||||
# get initial sequence pair features
|
||||
seq = self.emb(seq) # (B, L, d_model//2)
|
||||
left = seq.unsqueeze(2).expand(-1,-1,L,-1)
|
||||
right = seq.unsqueeze(1).expand(-1,L,-1,-1)
|
||||
seqsep = torch.abs(idx[:,:,None]-idx[:,None,:])+1
|
||||
seqsep = torch.log(seqsep.float()).view(B,L,L,1)
|
||||
#
|
||||
templ = self.norm_templ(templ)
|
||||
pair = torch.cat((left, right, seqsep, templ), dim=-1)
|
||||
pair = self.projection(pair) # (B, L, L, d_model)
|
||||
|
||||
return self.pos(pair, idx)
|
||||
|
||||
class Pair_emb_wo_templ(nn.Module):
|
||||
#TODO: embedding without template info
|
||||
def __init__(self, d_model=128, d_seq=21, p_drop=0.1):
|
||||
super(Pair_emb_wo_templ, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.d_emb = d_model // 2
|
||||
self.emb = nn.Embedding(d_seq, self.d_emb)
|
||||
self.projection = nn.Linear(d_model + 1, d_model)
|
||||
self.pos = PositionalEncoding2D(d_model, p_drop=p_drop)
|
||||
def forward(self, seq, idx):
|
||||
# input:
|
||||
# seq: target sequence (B, L, 20)
|
||||
B = seq.shape[0]
|
||||
L = seq.shape[1]
|
||||
seq = self.emb(seq) # (B, L, d_model//2)
|
||||
left = seq.unsqueeze(2).expand(-1,-1,L,-1)
|
||||
right = seq.unsqueeze(1).expand(-1,L,-1,-1)
|
||||
seqsep = torch.abs(idx[:,:,None]-idx[:,None,:])+1
|
||||
seqsep = torch.log(seqsep.float()).view(B,L,L,1)
|
||||
#
|
||||
pair = torch.cat((left, right, seqsep), dim=-1)
|
||||
pair = self.projection(pair)
|
||||
return self.pos(pair, idx)
|
||||
|
116
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/InitStrGenerator.py
Normal file
116
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/InitStrGenerator.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from Transformer import LayerNorm, SequenceWeight
|
||||
|
||||
import torch_geometric
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.nn import TransformerConv
|
||||
|
||||
def get_seqsep(idx):
|
||||
'''
|
||||
Input:
|
||||
- idx: residue indices of given sequence (B,L)
|
||||
Output:
|
||||
- seqsep: sequence separation feature with sign (B, L, L, 1)
|
||||
Sergey found that having sign in seqsep features helps a little
|
||||
'''
|
||||
seqsep = idx[:,None,:] - idx[:,:,None]
|
||||
sign = torch.sign(seqsep)
|
||||
seqsep = torch.log(torch.abs(seqsep) + 1.0)
|
||||
seqsep = torch.clamp(seqsep, 0.0, 5.5)
|
||||
seqsep = sign * seqsep
|
||||
return seqsep.unsqueeze(-1)
|
||||
|
||||
def make_graph(node, idx, emb):
|
||||
''' create torch_geometric graph from Trunk outputs '''
|
||||
device = emb.device
|
||||
B, L = emb.shape[:2]
|
||||
|
||||
# |i-j| <= kmin (connect sequentially adjacent residues)
|
||||
sep = idx[:,None,:] - idx[:,:,None]
|
||||
sep = sep.abs()
|
||||
b, i, j = torch.where(sep > 0)
|
||||
|
||||
src = b*L+i
|
||||
tgt = b*L+j
|
||||
|
||||
x = node.reshape(B*L, -1)
|
||||
|
||||
G = Data(x=x,
|
||||
edge_index=torch.stack([src,tgt]),
|
||||
edge_attr=emb[b,i,j])
|
||||
|
||||
return G
|
||||
|
||||
class UniMPBlock(nn.Module):
|
||||
'''https://arxiv.org/pdf/2009.03509.pdf'''
|
||||
def __init__(self,
|
||||
node_dim=64,
|
||||
edge_dim=64,
|
||||
heads=4,
|
||||
dropout=0.15):
|
||||
super(UniMPBlock, self).__init__()
|
||||
|
||||
self.TConv = TransformerConv(node_dim, node_dim, heads, dropout=dropout, edge_dim=edge_dim)
|
||||
self.LNorm = LayerNorm(node_dim*heads)
|
||||
self.Linear = nn.Linear(node_dim*heads, node_dim)
|
||||
self.Activ = nn.ELU(inplace=True)
|
||||
|
||||
#@torch.cuda.amp.autocast(enabled=True)
|
||||
def forward(self, G):
|
||||
xin, e_idx, e_attr = G.x, G.edge_index, G.edge_attr
|
||||
x = self.TConv(xin, e_idx, e_attr)
|
||||
x = self.LNorm(x)
|
||||
x = self.Linear(x)
|
||||
out = self.Activ(x+xin)
|
||||
return Data(x=out, edge_index=e_idx, edge_attr=e_attr)
|
||||
|
||||
|
||||
class InitStr_Network(nn.Module):
|
||||
def __init__(self,
|
||||
node_dim_in=64,
|
||||
node_dim_hidden=64,
|
||||
edge_dim_in=128,
|
||||
edge_dim_hidden=64,
|
||||
nheads=4,
|
||||
nblocks=3,
|
||||
dropout=0.1):
|
||||
super(InitStr_Network, self).__init__()
|
||||
|
||||
# embedding layers for node and edge features
|
||||
self.norm_node = LayerNorm(node_dim_in)
|
||||
self.norm_edge = LayerNorm(edge_dim_in)
|
||||
self.encoder_seq = SequenceWeight(node_dim_in, 1, dropout=dropout)
|
||||
|
||||
self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden), nn.ELU(inplace=True))
|
||||
self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+1, edge_dim_hidden), nn.ELU(inplace=True))
|
||||
|
||||
# graph transformer
|
||||
blocks = [UniMPBlock(node_dim_hidden,edge_dim_hidden,nheads,dropout) for _ in range(nblocks)]
|
||||
self.transformer = nn.Sequential(*blocks)
|
||||
|
||||
# outputs
|
||||
self.get_xyz = nn.Linear(node_dim_hidden,9)
|
||||
|
||||
def forward(self, seq1hot, idx, msa, pair):
|
||||
B, N, L = msa.shape[:3]
|
||||
msa = self.norm_node(msa)
|
||||
pair = self.norm_edge(pair)
|
||||
|
||||
w_seq = self.encoder_seq(msa).reshape(B, L, 1, N).permute(0,3,1,2)
|
||||
msa = w_seq*msa
|
||||
msa = msa.sum(dim=1)
|
||||
node = torch.cat((msa, seq1hot), dim=-1)
|
||||
node = self.embed_x(node)
|
||||
|
||||
seqsep = get_seqsep(idx)
|
||||
pair = torch.cat((pair, seqsep), dim=-1)
|
||||
pair = self.embed_e(pair)
|
||||
|
||||
G = make_graph(node, idx, pair)
|
||||
Gout = self.transformer(G)
|
||||
|
||||
xyz = self.get_xyz(Gout.x)
|
||||
|
||||
return xyz.reshape(B, L, 3, 3) #torch.cat([xyz,node_emb],dim=-1)
|
175
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Refine_module.py
Normal file
175
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Refine_module.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from Transformer import LayerNorm
|
||||
from InitStrGenerator import make_graph
|
||||
from InitStrGenerator import get_seqsep, UniMPBlock
|
||||
from Attention_module_w_str import make_graph as make_graph_topk
|
||||
from Attention_module_w_str import get_bonded_neigh, rbf
|
||||
from SE3_network import SE3Transformer
|
||||
from Transformer import _get_clones, create_custom_forward
|
||||
# Re-generate initial coordinates based on 1) final pair features 2) predicted distogram
|
||||
# Then, refine it through multiple SE3 transformer block
|
||||
|
||||
class Regen_Network(nn.Module):
|
||||
def __init__(self,
|
||||
node_dim_in=64,
|
||||
node_dim_hidden=64,
|
||||
edge_dim_in=128,
|
||||
edge_dim_hidden=64,
|
||||
state_dim=8,
|
||||
nheads=4,
|
||||
nblocks=3,
|
||||
dropout=0.0):
|
||||
super(Regen_Network, self).__init__()
|
||||
|
||||
# embedding layers for node and edge features
|
||||
self.norm_node = LayerNorm(node_dim_in)
|
||||
self.norm_edge = LayerNorm(edge_dim_in)
|
||||
|
||||
self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden), LayerNorm(node_dim_hidden))
|
||||
self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+2, edge_dim_hidden), LayerNorm(edge_dim_hidden))
|
||||
|
||||
# graph transformer
|
||||
blocks = [UniMPBlock(node_dim_hidden,edge_dim_hidden,nheads,dropout) for _ in range(nblocks)]
|
||||
self.transformer = nn.Sequential(*blocks)
|
||||
|
||||
# outputs
|
||||
self.get_xyz = nn.Linear(node_dim_hidden,9)
|
||||
self.norm_state = LayerNorm(node_dim_hidden)
|
||||
self.get_state = nn.Linear(node_dim_hidden, state_dim)
|
||||
|
||||
def forward(self, seq1hot, idx, node, edge):
|
||||
B, L = node.shape[:2]
|
||||
node = self.norm_node(node)
|
||||
edge = self.norm_edge(edge)
|
||||
|
||||
node = torch.cat((node, seq1hot), dim=-1)
|
||||
node = self.embed_x(node)
|
||||
|
||||
seqsep = get_seqsep(idx)
|
||||
neighbor = get_bonded_neigh(idx)
|
||||
edge = torch.cat((edge, seqsep, neighbor), dim=-1)
|
||||
edge = self.embed_e(edge)
|
||||
|
||||
G = make_graph(node, idx, edge)
|
||||
Gout = self.transformer(G)
|
||||
|
||||
xyz = self.get_xyz(Gout.x)
|
||||
state = self.get_state(self.norm_state(Gout.x))
|
||||
return xyz.reshape(B, L, 3, 3) , state.reshape(B, L, -1)
|
||||
|
||||
class Refine_Network(nn.Module):
|
||||
def __init__(self, d_node=64, d_pair=128, d_state=16,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
|
||||
super(Refine_Network, self).__init__()
|
||||
self.norm_msa = LayerNorm(d_node)
|
||||
self.norm_pair = LayerNorm(d_pair)
|
||||
self.norm_state = LayerNorm(d_state)
|
||||
|
||||
self.embed_x = nn.Linear(d_node+21+d_state, SE3_param['l0_in_features'])
|
||||
self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
|
||||
self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
|
||||
|
||||
self.norm_node = LayerNorm(SE3_param['l0_in_features'])
|
||||
self.norm_edge1 = LayerNorm(SE3_param['num_edge_features'])
|
||||
self.norm_edge2 = LayerNorm(SE3_param['num_edge_features'])
|
||||
|
||||
self.se3 = SE3Transformer(**SE3_param)
|
||||
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def forward(self, msa, pair, xyz, state, seq1hot, idx, top_k=64):
|
||||
# process node & pair features
|
||||
B, L = msa.shape[:2]
|
||||
node = self.norm_msa(msa)
|
||||
pair = self.norm_pair(pair)
|
||||
state = self.norm_state(state)
|
||||
|
||||
node = torch.cat((node, seq1hot, state), dim=-1)
|
||||
node = self.norm_node(self.embed_x(node))
|
||||
pair = self.norm_edge1(self.embed_e1(pair))
|
||||
|
||||
neighbor = get_bonded_neigh(idx)
|
||||
rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
|
||||
pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
|
||||
pair = self.norm_edge2(self.embed_e2(pair))
|
||||
|
||||
# define graph
|
||||
G = make_graph_topk(xyz, pair, idx, top_k=top_k)
|
||||
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2) # l1 features = displacement vector to CA
|
||||
l1_feats = l1_feats.reshape(B*L, -1, 3)
|
||||
|
||||
# apply SE(3) Transformer & update coordinates
|
||||
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats)
|
||||
|
||||
state = shift['0'].reshape(B, L, -1) # (B, L, C)
|
||||
|
||||
offset = shift['1'].reshape(B, L, -1, 3) # (B, L, 3, 3)
|
||||
CA_new = xyz[:,:,1] + offset[:,:,1]
|
||||
N_new = CA_new + offset[:,:,0]
|
||||
C_new = CA_new + offset[:,:,2]
|
||||
xyz_new = torch.stack([N_new, CA_new, C_new], dim=2)
|
||||
|
||||
return xyz_new, state
|
||||
|
||||
class Refine_module(nn.Module):
|
||||
def __init__(self, n_module, d_node=64, d_node_hidden=64, d_pair=128, d_pair_hidden=64,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
|
||||
super(Refine_module, self).__init__()
|
||||
self.n_module = n_module
|
||||
self.proj_edge = nn.Linear(d_pair, d_pair_hidden*2)
|
||||
|
||||
self.regen_net = Regen_Network(node_dim_in=d_node, node_dim_hidden=d_node_hidden,
|
||||
edge_dim_in=d_pair_hidden*2, edge_dim_hidden=d_pair_hidden,
|
||||
state_dim=SE3_param['l0_out_features'],
|
||||
nheads=4, nblocks=3, dropout=p_drop)
|
||||
self.refine_net = _get_clones(Refine_Network(d_node=d_node, d_pair=d_pair_hidden*2,
|
||||
d_state=SE3_param['l0_out_features'],
|
||||
SE3_param=SE3_param, p_drop=p_drop), self.n_module)
|
||||
self.norm_state = LayerNorm(SE3_param['l0_out_features'])
|
||||
self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)
|
||||
|
||||
def forward(self, node, edge, seq1hot, idx, use_transf_checkpoint=False, eps=1e-4):
|
||||
edge = self.proj_edge(edge)
|
||||
|
||||
xyz, state = self.regen_net(seq1hot, idx, node, edge)
|
||||
|
||||
# DOUBLE IT w/ Mirror images
|
||||
xyz = torch.cat([xyz, xyz*torch.tensor([1,1,-1], dtype=xyz.dtype, device=xyz.device)])
|
||||
state = torch.cat([state, state])
|
||||
node = torch.cat([node, node])
|
||||
edge = torch.cat([edge, edge])
|
||||
idx = torch.cat([idx, idx])
|
||||
seq1hot = torch.cat([seq1hot, seq1hot])
|
||||
|
||||
best_xyz = xyz
|
||||
best_lddt = torch.zeros((xyz.shape[0], xyz.shape[1], 1), device=xyz.device)
|
||||
prev_lddt = 0.0
|
||||
no_impr = 0
|
||||
no_impr_best = 0
|
||||
for i_iter in range(200):
|
||||
for i_m in range(self.n_module):
|
||||
if use_transf_checkpoint:
|
||||
xyz, state = checkpoint.checkpoint(create_custom_forward(self.refine_net[i_m], top_k=64), node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx)
|
||||
else:
|
||||
xyz, state = self.refine_net[i_m](node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx, top_k=64)
|
||||
#
|
||||
lddt = self.pred_lddt(self.norm_state(state))
|
||||
lddt = torch.clamp(lddt, 0.0, 1.0)[...,0]
|
||||
print (f"SE(3) iteration {i_iter} {lddt.mean(-1).cpu().numpy()}")
|
||||
if lddt.mean(-1).max() <= prev_lddt+eps:
|
||||
no_impr += 1
|
||||
else:
|
||||
no_impr = 0
|
||||
if lddt.mean(-1).max() <= best_lddt.mean(-1).max()+eps:
|
||||
no_impr_best += 1
|
||||
else:
|
||||
no_impr_best = 0
|
||||
if no_impr > 10 or no_impr_best > 20:
|
||||
break
|
||||
if lddt.mean(-1).max() > best_lddt.mean(-1).max():
|
||||
best_lddt = lddt
|
||||
best_xyz = xyz
|
||||
prev_lddt = lddt.mean(-1).max()
|
||||
pick = best_lddt.mean(-1).argmax()
|
||||
return best_xyz[pick][None], best_lddt[pick][None]
|
131
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/RoseTTAFoldModel.py
Normal file
131
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/RoseTTAFoldModel.py
Normal file
|
@ -0,0 +1,131 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from Embeddings import MSA_emb, Pair_emb_wo_templ, Pair_emb_w_templ, Templ_emb
|
||||
from Attention_module_w_str import IterativeFeatureExtractor
|
||||
from DistancePredictor import DistanceNetwork
|
||||
from Refine_module import Refine_module
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
|
||||
def __init__(self, n_module=4, n_module_str=4, n_layer=4,\
|
||||
d_msa=64, d_pair=128, d_templ=64,\
|
||||
n_head_msa=4, n_head_pair=8, n_head_templ=4,
|
||||
d_hidden=64, r_ff=4, n_resblock=1, p_drop=0.1,
|
||||
performer_L_opts=None, performer_N_opts=None,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
|
||||
use_templ=False):
|
||||
super(RoseTTAFoldModule, self).__init__()
|
||||
self.use_templ = use_templ
|
||||
#
|
||||
self.msa_emb = MSA_emb(d_model=d_msa, p_drop=p_drop, max_len=5000)
|
||||
if use_templ:
|
||||
self.templ_emb = Templ_emb(d_templ=d_templ, n_att_head=n_head_templ, r_ff=r_ff,
|
||||
performer_opts=performer_L_opts, p_drop=0.0)
|
||||
self.pair_emb = Pair_emb_w_templ(d_model=d_pair, d_templ=d_templ, p_drop=p_drop)
|
||||
else:
|
||||
self.pair_emb = Pair_emb_wo_templ(d_model=d_pair, p_drop=p_drop)
|
||||
#
|
||||
self.feat_extractor = IterativeFeatureExtractor(n_module=n_module,\
|
||||
n_module_str=n_module_str,\
|
||||
n_layer=n_layer,\
|
||||
d_msa=d_msa, d_pair=d_pair, d_hidden=d_hidden,\
|
||||
n_head_msa=n_head_msa, \
|
||||
n_head_pair=n_head_pair,\
|
||||
r_ff=r_ff, \
|
||||
n_resblock=n_resblock,
|
||||
p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts,
|
||||
SE3_param=SE3_param)
|
||||
self.c6d_predictor = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
|
||||
def forward(self, msa, seq, idx, t1d=None, t2d=None):
|
||||
B, N, L = msa.shape
|
||||
# Get embeddings
|
||||
msa = self.msa_emb(msa, idx)
|
||||
if self.use_templ:
|
||||
tmpl = self.templ_emb(t1d, t2d, idx)
|
||||
pair = self.pair_emb(seq, idx, tmpl)
|
||||
else:
|
||||
pair = self.pair_emb(seq, idx)
|
||||
#
|
||||
# Extract features
|
||||
seq1hot = torch.nn.functional.one_hot(seq, num_classes=21).float()
|
||||
msa, pair, xyz, lddt = self.feat_extractor(msa, pair, seq1hot, idx)
|
||||
|
||||
# Predict 6D coords
|
||||
logits = self.c6d_predictor(pair)
|
||||
|
||||
return logits, xyz, lddt.view(B, L)
|
||||
|
||||
|
||||
class RoseTTAFoldModule_e2e(nn.Module):
|
||||
def __init__(self, n_module=4, n_module_str=4, n_module_ref=4, n_layer=4,\
|
||||
d_msa=64, d_pair=128, d_templ=64,\
|
||||
n_head_msa=4, n_head_pair=8, n_head_templ=4,
|
||||
d_hidden=64, r_ff=4, n_resblock=1, p_drop=0.0,
|
||||
performer_L_opts=None, performer_N_opts=None,
|
||||
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
|
||||
REF_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
|
||||
use_templ=False):
|
||||
super(RoseTTAFoldModule_e2e, self).__init__()
|
||||
self.use_templ = use_templ
|
||||
#
|
||||
self.msa_emb = MSA_emb(d_model=d_msa, p_drop=p_drop, max_len=5000)
|
||||
if use_templ:
|
||||
self.templ_emb = Templ_emb(d_templ=d_templ, n_att_head=n_head_templ, r_ff=r_ff,
|
||||
performer_opts=performer_L_opts, p_drop=0.0)
|
||||
self.pair_emb = Pair_emb_w_templ(d_model=d_pair, d_templ=d_templ, p_drop=p_drop)
|
||||
else:
|
||||
self.pair_emb = Pair_emb_wo_templ(d_model=d_pair, p_drop=p_drop)
|
||||
#
|
||||
self.feat_extractor = IterativeFeatureExtractor(n_module=n_module,\
|
||||
n_module_str=n_module_str,\
|
||||
n_layer=n_layer,\
|
||||
d_msa=d_msa, d_pair=d_pair, d_hidden=d_hidden,\
|
||||
n_head_msa=n_head_msa, \
|
||||
n_head_pair=n_head_pair,\
|
||||
r_ff=r_ff, \
|
||||
n_resblock=n_resblock,
|
||||
p_drop=p_drop,
|
||||
performer_N_opts=performer_N_opts,
|
||||
performer_L_opts=performer_L_opts,
|
||||
SE3_param=SE3_param)
|
||||
self.c6d_predictor = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
#
|
||||
self.refine = Refine_module(n_module_ref, d_node=d_msa, d_pair=130,
|
||||
d_node_hidden=d_hidden, d_pair_hidden=d_hidden,
|
||||
SE3_param=REF_param, p_drop=p_drop)
|
||||
|
||||
def forward(self, msa, seq, idx, t1d=None, t2d=None, prob_s=None, return_raw=False, refine_only=False):
|
||||
seq1hot = torch.nn.functional.one_hot(seq, num_classes=21).float()
|
||||
if not refine_only:
|
||||
B, N, L = msa.shape
|
||||
# Get embeddings
|
||||
msa = self.msa_emb(msa, idx)
|
||||
if self.use_templ:
|
||||
tmpl = self.templ_emb(t1d, t2d, idx)
|
||||
pair = self.pair_emb(seq, idx, tmpl)
|
||||
else:
|
||||
pair = self.pair_emb(seq, idx)
|
||||
#
|
||||
# Extract features
|
||||
msa, pair, xyz, lddt = self.feat_extractor(msa, pair, seq1hot, idx)
|
||||
|
||||
# Predict 6D coords
|
||||
logits = self.c6d_predictor(pair)
|
||||
|
||||
prob_s = list()
|
||||
for l in logits:
|
||||
prob_s.append(nn.Softmax(dim=1)(l)) # (B, C, L, L)
|
||||
prob_s = torch.cat(prob_s, dim=1).permute(0,2,3,1)
|
||||
|
||||
B, L = msa.shape[:2]
|
||||
if return_raw:
|
||||
return logits, msa, xyz, lddt.view(B, L)
|
||||
|
||||
ref_xyz, ref_lddt = self.refine(msa, prob_s, seq1hot, idx)
|
||||
|
||||
if refine_only:
|
||||
return ref_xyz, ref_lddt.view(B,L)
|
||||
else:
|
||||
return logits, msa, ref_xyz, ref_lddt.view(B,L)
|
108
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/SE3_network.py
Normal file
108
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/SE3_network.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
|
||||
from equivariant_attention.modules import GConvSE3, GNormSE3
|
||||
from equivariant_attention.fibers import Fiber
|
||||
|
||||
class TFN(nn.Module):
|
||||
"""SE(3) equivariant GCN"""
|
||||
def __init__(self, num_layers=2, num_channels=32, num_nonlin_layers=1, num_degrees=3,
|
||||
l0_in_features=32, l0_out_features=32,
|
||||
l1_in_features=3, l1_out_features=3,
|
||||
num_edge_features=32, use_self=True):
|
||||
super().__init__()
|
||||
# Build the network
|
||||
self.num_layers = num_layers
|
||||
self.num_nlayers = num_nonlin_layers
|
||||
self.num_channels = num_channels
|
||||
self.num_degrees = num_degrees
|
||||
self.edge_dim = num_edge_features
|
||||
self.use_self = use_self
|
||||
|
||||
if l1_out_features > 0:
|
||||
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
|
||||
'mid': Fiber(self.num_degrees, self.num_channels),
|
||||
'out': Fiber(dictionary={0: l0_out_features, 1: l1_out_features})}
|
||||
else:
|
||||
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
|
||||
'mid': Fiber(self.num_degrees, self.num_channels),
|
||||
'out': Fiber(dictionary={0: l0_out_features})}
|
||||
blocks = self._build_gcn(fibers)
|
||||
self.block0 = blocks
|
||||
|
||||
def _build_gcn(self, fibers):
|
||||
|
||||
block0 = []
|
||||
fin = fibers['in']
|
||||
for i in range(self.num_layers-1):
|
||||
block0.append(GConvSE3(fin, fibers['mid'], self_interaction=self.use_self, edge_dim=self.edge_dim))
|
||||
block0.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers))
|
||||
fin = fibers['mid']
|
||||
block0.append(GConvSE3(fibers['mid'], fibers['out'], self_interaction=self.use_self, edge_dim=self.edge_dim))
|
||||
return nn.ModuleList(block0)
|
||||
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def forward(self, G, type_0_features, type_1_features):
|
||||
# Compute equivariant weight basis from relative positions
|
||||
basis, r = get_basis_and_r(G, self.num_degrees-1)
|
||||
h = {'0': type_0_features, '1': type_1_features}
|
||||
for layer in self.block0:
|
||||
h = layer(h, G=G, r=r, basis=basis)
|
||||
return h
|
||||
|
||||
class SE3Transformer(nn.Module):
|
||||
"""SE(3) equivariant GCN with attention"""
|
||||
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
|
||||
si_m='1x1', si_e='att',
|
||||
l0_in_features=32, l0_out_features=32,
|
||||
l1_in_features=3, l1_out_features=3,
|
||||
num_edge_features=32, x_ij=None):
|
||||
super().__init__()
|
||||
# Build the network
|
||||
self.num_layers = num_layers
|
||||
self.num_channels = num_channels
|
||||
self.num_degrees = num_degrees
|
||||
self.edge_dim = num_edge_features
|
||||
self.div = div
|
||||
self.n_heads = n_heads
|
||||
self.si_m, self.si_e = si_m, si_e
|
||||
self.x_ij = x_ij
|
||||
|
||||
if l1_out_features > 0:
|
||||
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
|
||||
'mid': Fiber(self.num_degrees, self.num_channels),
|
||||
'out': Fiber(dictionary={0: l0_out_features, 1: l1_out_features})}
|
||||
else:
|
||||
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
|
||||
'mid': Fiber(self.num_degrees, self.num_channels),
|
||||
'out': Fiber(dictionary={0: l0_out_features})}
|
||||
|
||||
blocks = self._build_gcn(fibers)
|
||||
self.Gblock = blocks
|
||||
|
||||
def _build_gcn(self, fibers):
|
||||
# Equivariant layers
|
||||
Gblock = []
|
||||
fin = fibers['in']
|
||||
for i in range(self.num_layers):
|
||||
Gblock.append(GSE3Res(fin, fibers['mid'], edge_dim=self.edge_dim,
|
||||
div=self.div, n_heads=self.n_heads,
|
||||
learnable_skip=True, skip='cat',
|
||||
selfint=self.si_m, x_ij=self.x_ij))
|
||||
Gblock.append(GNormBias(fibers['mid']))
|
||||
fin = fibers['mid']
|
||||
Gblock.append(
|
||||
GSE3Res(fibers['mid'], fibers['out'], edge_dim=self.edge_dim,
|
||||
div=1, n_heads=min(1, 2), learnable_skip=True,
|
||||
skip='cat', selfint=self.si_e, x_ij=self.x_ij))
|
||||
return nn.ModuleList(Gblock)
|
||||
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def forward(self, G, type_0_features, type_1_features):
|
||||
# Compute equivariant weight basis from relative positions
|
||||
basis, r = get_basis_and_r(G, self.num_degrees-1)
|
||||
h = {'0': type_0_features, '1': type_1_features}
|
||||
for layer in self.Gblock:
|
||||
h = layer(h, G=G, r=r, basis=basis)
|
||||
return h
|
480
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Transformer.py
Normal file
480
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/Transformer.py
Normal file
|
@ -0,0 +1,480 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
import math
|
||||
from performer_pytorch import SelfAttention
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
# for gradient checkpointing
|
||||
def create_custom_forward(module, **kwargs):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, **kwargs)
|
||||
return custom_forward
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, d_model, eps=1e-5):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.a_2 = nn.Parameter(torch.ones(d_model))
|
||||
self.b_2 = nn.Parameter(torch.zeros(d_model))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(-1, keepdim=True)
|
||||
std = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
|
||||
x = self.a_2*(x-mean)
|
||||
x /= std
|
||||
x += self.b_2
|
||||
return x
|
||||
|
||||
class FeedForwardLayer(nn.Module):
|
||||
def __init__(self, d_model, d_ff, p_drop=0.1):
|
||||
super(FeedForwardLayer, self).__init__()
|
||||
self.linear1 = nn.Linear(d_model, d_ff)
|
||||
self.dropout = nn.Dropout(p_drop, inplace=True)
|
||||
self.linear2 = nn.Linear(d_ff, d_model)
|
||||
|
||||
def forward(self, src):
|
||||
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
|
||||
return src
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
if k_dim == None:
|
||||
k_dim = d_model
|
||||
if v_dim == None:
|
||||
v_dim = d_model
|
||||
|
||||
self.heads = heads
|
||||
self.d_model = d_model
|
||||
self.d_k = d_model // heads
|
||||
self.scaling = 1/math.sqrt(self.d_k)
|
||||
|
||||
self.to_query = nn.Linear(d_model, d_model)
|
||||
self.to_key = nn.Linear(k_dim, d_model)
|
||||
self.to_value = nn.Linear(v_dim, d_model)
|
||||
self.to_out = nn.Linear(d_model, d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
def forward(self, query, key, value, return_att=False):
|
||||
batch, L1 = query.shape[:2]
|
||||
batch, L2 = key.shape[:2]
|
||||
q = self.to_query(query).view(batch, L1, self.heads, self.d_k).permute(0,2,1,3) # (B, h, L, d_k)
|
||||
k = self.to_key(key).view(batch, L2, self.heads, self.d_k).permute(0,2,1,3) # (B, h, L, d_k)
|
||||
v = self.to_value(value).view(batch, L2, self.heads, self.d_k).permute(0,2,1,3)
|
||||
#
|
||||
attention = torch.matmul(q, k.transpose(-2, -1))*self.scaling
|
||||
attention = F.softmax(attention, dim=-1) # (B, h, L1, L2)
|
||||
attention = self.dropout(attention)
|
||||
#
|
||||
out = torch.matmul(attention, v) # (B, h, L, d_k)
|
||||
out = out.permute(0,2,1,3).contiguous().view(batch, L1, -1)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
if return_att:
|
||||
attention = 0.5*(attention + attention.permute(0,1,3,2))
|
||||
return out, attention.permute(0,2,3,1)
|
||||
return out
|
||||
|
||||
# Own implementation for tied multihead attention
|
||||
class TiedMultiheadAttention(nn.Module):
|
||||
def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
|
||||
super(TiedMultiheadAttention, self).__init__()
|
||||
if k_dim == None:
|
||||
k_dim = d_model
|
||||
if v_dim == None:
|
||||
v_dim = d_model
|
||||
|
||||
self.heads = heads
|
||||
self.d_model = d_model
|
||||
self.d_k = d_model // heads
|
||||
self.scaling = 1/math.sqrt(self.d_k)
|
||||
|
||||
self.to_query = nn.Linear(d_model, d_model)
|
||||
self.to_key = nn.Linear(k_dim, d_model)
|
||||
self.to_value = nn.Linear(v_dim, d_model)
|
||||
self.to_out = nn.Linear(d_model, d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
def forward(self, query, key, value, return_att=False):
|
||||
B, N, L = query.shape[:3]
|
||||
q = self.to_query(query).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
|
||||
k = self.to_key(key).view(B, N, L, self.heads, self.d_k).permute(0,1,3,4,2).contiguous() # (B, N, h, k, l)
|
||||
v = self.to_value(value).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
|
||||
#
|
||||
#attention = torch.matmul(q, k.transpose(-2, -1))/math.sqrt(N*self.d_k) # (B, N, h, L, L)
|
||||
#attention = attention.sum(dim=1) # tied attention (B, h, L, L)
|
||||
scale = self.scaling / math.sqrt(N)
|
||||
q = q * scale
|
||||
attention = torch.einsum('bnhik,bnhkj->bhij', q, k)
|
||||
attention = F.softmax(attention, dim=-1) # (B, h, L, L)
|
||||
attention = self.dropout(attention)
|
||||
attention = attention.unsqueeze(1) # (B, 1, h, L, L)
|
||||
#
|
||||
out = torch.matmul(attention, v) # (B, N, h, L, d_k)
|
||||
out = out.permute(0,1,3,2,4).contiguous().view(B, N, L, -1)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
if return_att:
|
||||
attention = attention.squeeze(1)
|
||||
attention = 0.5*(attention + attention.permute(0,1,3,2))
|
||||
attention = attention.permute(0,3,1,2)
|
||||
return out, attention
|
||||
return out
|
||||
|
||||
class SequenceWeight(nn.Module):
|
||||
def __init__(self, d_model, heads, dropout=0.1):
|
||||
super(SequenceWeight, self).__init__()
|
||||
self.heads = heads
|
||||
self.d_model = d_model
|
||||
self.d_k = d_model // heads
|
||||
self.scale = 1.0 / math.sqrt(self.d_k)
|
||||
|
||||
self.to_query = nn.Linear(d_model, d_model)
|
||||
self.to_key = nn.Linear(d_model, d_model)
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
def forward(self, msa):
|
||||
B, N, L = msa.shape[:3]
|
||||
|
||||
msa = msa.permute(0,2,1,3) # (B, L, N, K)
|
||||
tar_seq = msa[:,:,0].unsqueeze(2) # (B, L, 1, K)
|
||||
|
||||
q = self.to_query(tar_seq).view(B, L, 1, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, L, h, 1, k)
|
||||
k = self.to_key(msa).view(B, L, N, self.heads, self.d_k).permute(0,1,3,4,2).contiguous() # (B, L, h, k, N)
|
||||
|
||||
q = q * self.scale
|
||||
attn = torch.matmul(q, k) # (B, L, h, 1, N)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
return self.dropout(attn)
|
||||
|
||||
# Own implementation for multihead attention (Input shape: Batch, Len, Emb)
|
||||
class SoftTiedMultiheadAttention(nn.Module):
|
||||
def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
|
||||
super(SoftTiedMultiheadAttention, self).__init__()
|
||||
if k_dim == None:
|
||||
k_dim = d_model
|
||||
if v_dim == None:
|
||||
v_dim = d_model
|
||||
|
||||
self.heads = heads
|
||||
self.d_model = d_model
|
||||
self.d_k = d_model // heads
|
||||
self.scale = 1.0 / math.sqrt(self.d_k)
|
||||
|
||||
self.seq_weight = SequenceWeight(d_model, heads, dropout=dropout)
|
||||
self.to_query = nn.Linear(d_model, d_model)
|
||||
self.to_key = nn.Linear(k_dim, d_model)
|
||||
self.to_value = nn.Linear(v_dim, d_model)
|
||||
self.to_out = nn.Linear(d_model, d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
def forward(self, query, key, value, return_att=False):
|
||||
B, N, L = query.shape[:3]
|
||||
#
|
||||
seq_weight = self.seq_weight(query) # (B, L, h, 1, N)
|
||||
seq_weight = seq_weight.permute(0,4,2,1,3) # (B, N, h, l, -1)
|
||||
#
|
||||
q = self.to_query(query).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
|
||||
k = self.to_key(key).view(B, N, L, self.heads, self.d_k).permute(0,1,3,4,2).contiguous() # (B, N, h, k, l)
|
||||
v = self.to_value(value).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
|
||||
#
|
||||
#attention = torch.matmul(q, k.transpose(-2, -1))/math.sqrt(N*self.d_k) # (B, N, h, L, L)
|
||||
#attention = attention.sum(dim=1) # tied attention (B, h, L, L)
|
||||
q = q * seq_weight # (B, N, h, l, k)
|
||||
k = k * self.scale
|
||||
attention = torch.einsum('bnhik,bnhkj->bhij', q, k)
|
||||
attention = F.softmax(attention, dim=-1) # (B, h, L, L)
|
||||
attention = self.dropout(attention)
|
||||
attention = attention # (B, 1, h, L, L)
|
||||
del q, k, seq_weight
|
||||
#
|
||||
#out = torch.matmul(attention, v) # (B, N, h, L, d_k)
|
||||
out = torch.einsum('bhij,bnhjk->bnhik', attention, v)
|
||||
out = out.permute(0,1,3,2,4).contiguous().view(B, N, L, -1)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
|
||||
if return_att:
|
||||
attention = attention.squeeze(1)
|
||||
attention = 0.5*(attention + attention.permute(0,1,3,2))
|
||||
attention = attention.permute(0,2,3,1)
|
||||
return out, attention
|
||||
return out
|
||||
|
||||
class DirectMultiheadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, heads, dropout=0.1):
|
||||
super(DirectMultiheadAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.proj_pair = nn.Linear(d_in, heads)
|
||||
self.drop = nn.Dropout(dropout, inplace=True)
|
||||
# linear projection to get values from given msa
|
||||
self.proj_msa = nn.Linear(d_out, d_out)
|
||||
# projection after applying attention
|
||||
self.proj_out = nn.Linear(d_out, d_out)
|
||||
|
||||
def forward(self, src, tgt):
|
||||
B, N, L = tgt.shape[:3]
|
||||
attn_map = F.softmax(self.proj_pair(src), dim=1).permute(0,3,1,2) # (B, h, L, L)
|
||||
attn_map = self.drop(attn_map).unsqueeze(1)
|
||||
|
||||
# apply attention
|
||||
value = self.proj_msa(tgt).permute(0,3,1,2).contiguous().view(B, -1, self.heads, N, L) # (B,-1, h, N, L)
|
||||
tgt = torch.matmul(value, attn_map).view(B, -1, N, L).permute(0,2,3,1) # (B,N,L,K)
|
||||
tgt = self.proj_out(tgt)
|
||||
return tgt
|
||||
|
||||
class MaskedDirectMultiheadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, heads, d_k=32, dropout=0.1):
|
||||
super(MaskedDirectMultiheadAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.scaling = 1/math.sqrt(d_k)
|
||||
|
||||
self.to_query = nn.Linear(d_in, heads*d_k)
|
||||
self.to_key = nn.Linear(d_in, heads*d_k)
|
||||
self.to_value = nn.Linear(d_out, d_out)
|
||||
self.to_out = nn.Linear(d_out, d_out)
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
batch, N, L = value.shape[:3]
|
||||
#
|
||||
# project to query, key, value
|
||||
q = self.to_query(query).view(batch, L, self.heads, -1).permute(0,2,1,3) # (B, h, L, -1)
|
||||
k = self.to_key(key).view(batch, L, self.heads, -1).permute(0,2,1,3) # (B, h, L, -1)
|
||||
v = self.to_value(value).view(batch, N, L, self.heads, -1).permute(0,3,1,2,4) # (B, h, N, L, -1)
|
||||
#
|
||||
q = q*self.scaling
|
||||
attention = torch.matmul(q, k.transpose(-2, -1)) # (B, h, L, L)
|
||||
attention = attention.masked_fill(mask < 0.5, torch.finfo(q.dtype).min)
|
||||
attention = F.softmax(attention, dim=-1) # (B, h, L1, L2)
|
||||
attention = self.dropout(attention) # (B, h, 1, L, L)
|
||||
#
|
||||
#out = torch.matmul(attention, v) # (B, h, N, L, d_out//h)
|
||||
out = torch.einsum('bhij,bhnjk->bhnik', attention, v) # (B, h, N, L, d_out//h)
|
||||
out = out.permute(0,2,3,1,4).contiguous().view(batch, N, L, -1)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
# Use PreLayerNorm for more stable training
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model, d_ff, heads, p_drop=0.1, performer_opts=None, use_tied=False):
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.use_performer = performer_opts is not None
|
||||
self.use_tied = use_tied
|
||||
# multihead attention
|
||||
if self.use_performer:
|
||||
self.attn = SelfAttention(dim=d_model, heads=heads, dropout=p_drop,
|
||||
generalized_attention=True, **performer_opts)
|
||||
elif use_tied:
|
||||
self.attn = SoftTiedMultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
else:
|
||||
self.attn = MultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
# feedforward
|
||||
self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
|
||||
|
||||
# normalization module
|
||||
self.norm1 = LayerNorm(d_model)
|
||||
self.norm2 = LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p_drop, inplace=True)
|
||||
self.dropout2 = nn.Dropout(p_drop, inplace=True)
|
||||
|
||||
def forward(self, src, return_att=False):
|
||||
# Input shape for multihead attention: (BATCH, SRCLEN, EMB)
|
||||
# multihead attention w/ pre-LayerNorm
|
||||
B, N, L = src.shape[:3]
|
||||
src2 = self.norm1(src)
|
||||
if not self.use_tied:
|
||||
src2 = src2.reshape(B*N, L, -1)
|
||||
if return_att:
|
||||
src2, att = self.attn(src2, src2, src2, return_att=return_att)
|
||||
src2 = src2.reshape(B,N,L,-1)
|
||||
else:
|
||||
src2 = self.attn(src2, src2, src2).reshape(B,N,L,-1)
|
||||
src = src + self.dropout1(src2)
|
||||
|
||||
# feed-forward
|
||||
src2 = self.norm2(src) # pre-normalization
|
||||
src2 = self.ff(src2)
|
||||
src = src + self.dropout2(src2)
|
||||
if return_att:
|
||||
return src, att
|
||||
return src
|
||||
|
||||
# AxialTransformer with tied attention for L dimension
|
||||
class AxialEncoderLayer(nn.Module):
|
||||
def __init__(self, d_model, d_ff, heads, p_drop=0.1, performer_opts=None,
|
||||
use_tied_row=False, use_tied_col=False, use_soft_row=False):
|
||||
super(AxialEncoderLayer, self).__init__()
|
||||
self.use_performer = performer_opts is not None
|
||||
self.use_tied_row = use_tied_row
|
||||
self.use_tied_col = use_tied_col
|
||||
self.use_soft_row = use_soft_row
|
||||
# multihead attention
|
||||
if use_tied_row:
|
||||
self.attn_L = TiedMultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
elif use_soft_row:
|
||||
self.attn_L = SoftTiedMultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
else:
|
||||
if self.use_performer:
|
||||
self.attn_L = SelfAttention(dim=d_model, heads=heads, dropout=p_drop,
|
||||
generalized_attention=True, **performer_opts)
|
||||
else:
|
||||
self.attn_L = MultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
if use_tied_col:
|
||||
self.attn_N = TiedMultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
else:
|
||||
if self.use_performer:
|
||||
self.attn_N = SelfAttention(dim=d_model, heads=heads, dropout=p_drop,
|
||||
generalized_attention=True, **performer_opts)
|
||||
else:
|
||||
self.attn_N = MultiheadAttention(d_model, heads, dropout=p_drop)
|
||||
|
||||
# feedforward
|
||||
self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
|
||||
|
||||
# normalization module
|
||||
self.norm1 = LayerNorm(d_model)
|
||||
self.norm2 = LayerNorm(d_model)
|
||||
self.norm3 = LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p_drop, inplace=True)
|
||||
self.dropout2 = nn.Dropout(p_drop, inplace=True)
|
||||
self.dropout3 = nn.Dropout(p_drop, inplace=True)
|
||||
|
||||
def forward(self, src, return_att=False):
|
||||
# Input shape for multihead attention: (BATCH, NSEQ, NRES, EMB)
|
||||
# Tied multihead attention w/ pre-LayerNorm
|
||||
B, N, L = src.shape[:3]
|
||||
src2 = self.norm1(src)
|
||||
if self.use_tied_row or self.use_soft_row:
|
||||
src2 = self.attn_L(src2, src2, src2) # Tied attention over L
|
||||
else:
|
||||
src2 = src2.reshape(B*N, L, -1)
|
||||
src2 = self.attn_L(src2, src2, src2)
|
||||
src2 = src2.reshape(B, N, L, -1)
|
||||
src = src + self.dropout1(src2)
|
||||
|
||||
# attention over N
|
||||
src2 = self.norm2(src)
|
||||
if self.use_tied_col:
|
||||
src2 = src2.permute(0,2,1,3)
|
||||
src2 = self.attn_N(src2, src2, src2) # Tied attention over N
|
||||
src2 = src2.permute(0,2,1,3)
|
||||
else:
|
||||
src2 = src2.permute(0,2,1,3).reshape(B*L, N, -1)
|
||||
src2 = self.attn_N(src2, src2, src2) # attention over N
|
||||
src2 = src2.reshape(B, L, N, -1).permute(0,2,1,3)
|
||||
src = src + self.dropout2(src2)
|
||||
|
||||
# feed-forward
|
||||
src2 = self.norm3(src) # pre-normalization
|
||||
src2 = self.ff(src2)
|
||||
src = src + self.dropout3(src2)
|
||||
return src
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, enc_layer, n_layer):
|
||||
super(Encoder, self).__init__()
|
||||
self.layers = _get_clones(enc_layer, n_layer)
|
||||
self.n_layer = n_layer
|
||||
|
||||
def forward(self, src, return_att=False):
|
||||
output = src
|
||||
for layer in self.layers:
|
||||
output = layer(output, return_att=return_att)
|
||||
return output
|
||||
|
||||
class CrossEncoderLayer(nn.Module):
|
||||
def __init__(self, d_model, d_ff, heads, d_k, d_v, performer_opts=None, p_drop=0.1):
|
||||
super(CrossEncoderLayer, self).__init__()
|
||||
self.use_performer = performer_opts is not None
|
||||
|
||||
# multihead attention
|
||||
if self.use_performer:
|
||||
self.attn = SelfAttention(dim=d_model, k_dim=d_k, heads=heads, dropout=p_drop,
|
||||
generalized_attention=True, **performer_opts)
|
||||
else:
|
||||
self.attn = MultiheadAttention(d_model, heads, k_dim=d_k, v_dim=d_v, dropout=p_drop)
|
||||
# feedforward
|
||||
self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
|
||||
|
||||
# normalization module
|
||||
self.norm = LayerNorm(d_k)
|
||||
self.norm1 = LayerNorm(d_model)
|
||||
self.norm2 = LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p_drop, inplace=True)
|
||||
self.dropout2 = nn.Dropout(p_drop, inplace=True)
|
||||
|
||||
def forward(self, src, tgt):
|
||||
# Input:
|
||||
# For MSA to Pair: src (N, L, K), tgt (L, L, C)
|
||||
# For Pair to MSA: src (L, L, C), tgt (N, L, K)
|
||||
# Input shape for multihead attention: (SRCLEN, BATCH, EMB)
|
||||
# multihead attention
|
||||
# pre-normalization
|
||||
src = self.norm(src)
|
||||
tgt2 = self.norm1(tgt)
|
||||
tgt2 = self.attn(tgt2, src, src) # projection to query, key, value are done in MultiheadAttention module
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
# Feed forward
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.ff(tgt2)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
|
||||
return tgt
|
||||
|
||||
class DirectEncoderLayer(nn.Module):
|
||||
def __init__(self, heads, d_in, d_out, d_ff, symmetrize=True, p_drop=0.1):
|
||||
super(DirectEncoderLayer, self).__init__()
|
||||
self.symmetrize = symmetrize
|
||||
|
||||
self.attn = DirectMultiheadAttention(d_in, d_out, heads, dropout=p_drop)
|
||||
self.ff = FeedForwardLayer(d_out, d_ff, p_drop=p_drop)
|
||||
|
||||
# dropouts
|
||||
self.drop_1 = nn.Dropout(p_drop, inplace=True)
|
||||
self.drop_2 = nn.Dropout(p_drop, inplace=True)
|
||||
# LayerNorm
|
||||
self.norm = LayerNorm(d_in)
|
||||
self.norm1 = LayerNorm(d_out)
|
||||
self.norm2 = LayerNorm(d_out)
|
||||
|
||||
def forward(self, src, tgt):
|
||||
# Input:
|
||||
# For pair to msa: src=pair (B, L, L, C), tgt=msa (B, N, L, K)
|
||||
B, N, L = tgt.shape[:3]
|
||||
# get attention map
|
||||
if self.symmetrize:
|
||||
src = 0.5*(src + src.permute(0,2,1,3))
|
||||
src = self.norm(src)
|
||||
tgt2 = self.norm1(tgt)
|
||||
tgt2 = self.attn(src, tgt2)
|
||||
tgt = tgt + self.drop_1(tgt2)
|
||||
|
||||
# feed-forward
|
||||
tgt2 = self.norm2(tgt.view(B*N,L,-1)).view(B,N,L,-1)
|
||||
tgt2 = self.ff(tgt2)
|
||||
tgt = tgt + self.drop_2(tgt2)
|
||||
|
||||
return tgt
|
||||
|
||||
class CrossEncoder(nn.Module):
|
||||
def __init__(self, enc_layer, n_layer):
|
||||
super(CrossEncoder, self).__init__()
|
||||
self.layers = _get_clones(enc_layer, n_layer)
|
||||
self.n_layer = n_layer
|
||||
def forward(self, src, tgt):
|
||||
output = tgt
|
||||
for layer in self.layers:
|
||||
output = layer(src, output)
|
||||
return output
|
||||
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
from utils.utils_profiling import * # load before other local modules
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class Fiber(object):
|
||||
"""A Handy Data Structure for Fibers"""
|
||||
def __init__(self, num_degrees: int=None, num_channels: int=None,
|
||||
structure: List[Tuple[int,int]]=None, dictionary=None):
|
||||
"""
|
||||
define fiber structure; use one num_degrees & num_channels OR structure
|
||||
OR dictionary
|
||||
|
||||
:param num_degrees: degrees will be [0, ..., num_degrees-1]
|
||||
:param num_channels: number of channels, same for each degree
|
||||
:param structure: e.g. [(32, 0),(16, 1),(16,2)]
|
||||
:param dictionary: e.g. {0:32, 1:16, 2:16}
|
||||
"""
|
||||
if structure:
|
||||
self.structure = structure
|
||||
elif dictionary:
|
||||
self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
|
||||
else:
|
||||
self.structure = [(num_channels, i) for i in range(num_degrees)]
|
||||
|
||||
self.multiplicities, self.degrees = zip(*self.structure)
|
||||
self.max_degree = max(self.degrees)
|
||||
self.min_degree = min(self.degrees)
|
||||
self.structure_dict = {k: v for v, k in self.structure}
|
||||
self.dict = self.structure_dict
|
||||
self.n_features = np.sum([i[0] * (2*i[1]+1) for i in self.structure])
|
||||
|
||||
self.feature_indices = {}
|
||||
idx = 0
|
||||
for (num_channels, d) in self.structure:
|
||||
length = num_channels * (2*d + 1)
|
||||
self.feature_indices[d] = (idx, idx + length)
|
||||
idx += length
|
||||
|
||||
def copy_me(self, multiplicity: int=None):
|
||||
s = copy.deepcopy(self.structure)
|
||||
if multiplicity is not None:
|
||||
# overwrite multiplicities
|
||||
s = [(multiplicity, o) for m, o in s]
|
||||
return Fiber(structure=s)
|
||||
|
||||
@staticmethod
|
||||
def combine(f1, f2):
|
||||
new_dict = copy.deepcopy(f1.structure_dict)
|
||||
for k, m in f2.structure_dict.items():
|
||||
if k in new_dict.keys():
|
||||
new_dict[k] += m
|
||||
else:
|
||||
new_dict[k] = m
|
||||
structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
|
||||
return Fiber(structure=structure)
|
||||
|
||||
@staticmethod
|
||||
def combine_max(f1, f2):
|
||||
new_dict = copy.deepcopy(f1.structure_dict)
|
||||
for k, m in f2.structure_dict.items():
|
||||
if k in new_dict.keys():
|
||||
new_dict[k] = max(m, new_dict[k])
|
||||
else:
|
||||
new_dict[k] = m
|
||||
structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
|
||||
return Fiber(structure=structure)
|
||||
|
||||
@staticmethod
|
||||
def combine_selectively(f1, f2):
|
||||
# only use orders which occur in fiber f1
|
||||
|
||||
new_dict = copy.deepcopy(f1.structure_dict)
|
||||
for k in f1.degrees:
|
||||
if k in f2.degrees:
|
||||
new_dict[k] += f2.structure_dict[k]
|
||||
structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
|
||||
return Fiber(structure=structure)
|
||||
|
||||
@staticmethod
|
||||
def combine_fibers(val1, struc1, val2, struc2):
|
||||
"""
|
||||
combine two fibers
|
||||
|
||||
:param val1/2: fiber tensors in dictionary form
|
||||
:param struc1/2: structure of fiber
|
||||
:return: fiber tensor in dictionary form
|
||||
"""
|
||||
struc_out = Fiber.combine(struc1, struc2)
|
||||
val_out = {}
|
||||
for k in struc_out.degrees:
|
||||
if k in struc1.degrees:
|
||||
if k in struc2.degrees:
|
||||
val_out[k] = torch.cat([val1[k], val2[k]], -2)
|
||||
else:
|
||||
val_out[k] = val1[k]
|
||||
else:
|
||||
val_out[k] = val2[k]
|
||||
assert val_out[k].shape[-2] == struc_out.structure_dict[k]
|
||||
return val_out
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.structure}"
|
||||
|
||||
|
||||
|
||||
def get_fiber_dict(F, struc, mask=None, return_struc=False):
|
||||
if mask is None: mask = struc
|
||||
index = 0
|
||||
fiber_dict = {}
|
||||
first_dims = F.shape[:-1]
|
||||
masked_dict = {}
|
||||
for o, m in struc.structure_dict.items():
|
||||
length = m * (2*o + 1)
|
||||
if o in mask.degrees:
|
||||
masked_dict[o] = m
|
||||
fiber_dict[o] = F[...,index:index + length].view(list(first_dims) + [m, 2*o + 1])
|
||||
index += length
|
||||
assert F.shape[-1] == index
|
||||
if return_struc:
|
||||
return fiber_dict, Fiber(dictionary=masked_dict)
|
||||
return fiber_dict
|
||||
|
||||
|
||||
def get_fiber_tensor(F, struc):
|
||||
some_entry = tuple(F.values())[0]
|
||||
first_dims = some_entry.shape[:-2]
|
||||
res = some_entry.new_empty([*first_dims, struc.n_features])
|
||||
index = 0
|
||||
for o, m in struc.structure_dict.items():
|
||||
length = m * (2*o + 1)
|
||||
res[..., index: index + length] = F[o].view(*first_dims, length)
|
||||
index += length
|
||||
assert index == res.shape[-1]
|
||||
return res
|
||||
|
||||
|
||||
def fiber2tensor(F, structure, squeeze=False):
|
||||
if squeeze:
|
||||
fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1) for i in structure.degrees]
|
||||
fibers = torch.cat(fibers, -1)
|
||||
else:
|
||||
fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1, 1) for i in structure.degrees]
|
||||
fibers = torch.cat(fibers, -2)
|
||||
return fibers
|
||||
|
||||
|
||||
def fiber2head(F, h, structure, squeeze=False):
|
||||
if squeeze:
|
||||
fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1) for i in structure.degrees]
|
||||
fibers = torch.cat(fibers, -1)
|
||||
else:
|
||||
fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1, 1) for i in structure.degrees]
|
||||
fibers = torch.cat(fibers, -2)
|
||||
return fibers
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
# pylint: disable=C,E1101,E1102
|
||||
'''
|
||||
Some functions related to SO3 and his usual representations
|
||||
|
||||
Using ZYZ Euler angles parametrisation
|
||||
'''
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class torch_default_dtype:
|
||||
|
||||
def __init__(self, dtype):
|
||||
self.saved_dtype = None
|
||||
self.dtype = dtype
|
||||
|
||||
def __enter__(self):
|
||||
self.saved_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(self.dtype)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
torch.set_default_dtype(self.saved_dtype)
|
||||
|
||||
|
||||
def rot_z(gamma):
|
||||
'''
|
||||
Rotation around Z axis
|
||||
'''
|
||||
if not torch.is_tensor(gamma):
|
||||
gamma = torch.tensor(gamma, dtype=torch.get_default_dtype())
|
||||
return torch.tensor([
|
||||
[torch.cos(gamma), -torch.sin(gamma), 0],
|
||||
[torch.sin(gamma), torch.cos(gamma), 0],
|
||||
[0, 0, 1]
|
||||
], dtype=gamma.dtype)
|
||||
|
||||
|
||||
def rot_y(beta):
|
||||
'''
|
||||
Rotation around Y axis
|
||||
'''
|
||||
if not torch.is_tensor(beta):
|
||||
beta = torch.tensor(beta, dtype=torch.get_default_dtype())
|
||||
return torch.tensor([
|
||||
[torch.cos(beta), 0, torch.sin(beta)],
|
||||
[0, 1, 0],
|
||||
[-torch.sin(beta), 0, torch.cos(beta)]
|
||||
], dtype=beta.dtype)
|
||||
|
||||
|
||||
def rot(alpha, beta, gamma):
|
||||
'''
|
||||
ZYZ Eurler angles rotation
|
||||
'''
|
||||
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
|
||||
|
||||
|
||||
def x_to_alpha_beta(x):
|
||||
'''
|
||||
Convert point (x, y, z) on the sphere into (alpha, beta)
|
||||
'''
|
||||
if not torch.is_tensor(x):
|
||||
x = torch.tensor(x, dtype=torch.get_default_dtype())
|
||||
x = x / torch.norm(x)
|
||||
beta = torch.acos(x[2])
|
||||
alpha = torch.atan2(x[1], x[0])
|
||||
return (alpha, beta)
|
||||
|
||||
|
||||
# These functions (x_to_alpha_beta and rot) satisfies that
|
||||
# rot(*x_to_alpha_beta([x, y, z]), 0) @ np.array([[0], [0], [1]])
|
||||
# is proportional to
|
||||
# [x, y, z]
|
||||
|
||||
|
||||
def irr_repr(order, alpha, beta, gamma, dtype=None):
|
||||
"""
|
||||
irreducible representation of SO3
|
||||
- compatible with compose and spherical_harmonics
|
||||
"""
|
||||
# from from_lielearn_SO3.wigner_d import wigner_D_matrix
|
||||
from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
|
||||
# if order == 1:
|
||||
# # change of basis to have vector_field[x, y, z] = [vx, vy, vz]
|
||||
# A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
|
||||
# return A @ wigner_D_matrix(1, alpha, beta, gamma) @ A.T
|
||||
|
||||
# TODO (non-essential): try to do everything in torch
|
||||
# return torch.tensor(wigner_D_matrix(torch.tensor(order), alpha, beta, gamma), dtype=torch.get_default_dtype() if dtype is None else dtype)
|
||||
return torch.tensor(wigner_D_matrix(order, np.array(alpha), np.array(beta), np.array(gamma)), dtype=torch.get_default_dtype() if dtype is None else dtype)
|
||||
|
||||
|
||||
# def spherical_harmonics(order, alpha, beta, dtype=None):
|
||||
# """
|
||||
# spherical harmonics
|
||||
# - compatible with irr_repr and compose
|
||||
# """
|
||||
# # from from_lielearn_SO3.spherical_harmonics import sh
|
||||
# from lie_learn.representations.SO3.spherical_harmonics import sh # real valued by default
|
||||
#
|
||||
# ###################################################################################################################
|
||||
# # ON ANGLE CONVENTION
|
||||
# #
|
||||
# # sh has following convention for angles:
|
||||
# # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
|
||||
# # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
|
||||
# #
|
||||
# # this function therefore (probably) has the following convention for alpha and beta:
|
||||
# # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
|
||||
# # alpha = phi
|
||||
# #
|
||||
# ###################################################################################################################
|
||||
#
|
||||
# Y = torch.tensor([sh(order, m, theta=math.pi - beta, phi=alpha) for m in range(-order, order + 1)], dtype=torch.get_default_dtype() if dtype is None else dtype)
|
||||
# # if order == 1:
|
||||
# # # change of basis to have vector_field[x, y, z] = [vx, vy, vz]
|
||||
# # A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
|
||||
# # return A @ Y
|
||||
# return Y
|
||||
|
||||
|
||||
def compose(a1, b1, c1, a2, b2, c2):
|
||||
"""
|
||||
(a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)
|
||||
"""
|
||||
comp = rot(a1, b1, c1) @ rot(a2, b2, c2)
|
||||
xyz = comp @ torch.tensor([0, 0, 1.])
|
||||
a, b = x_to_alpha_beta(xyz)
|
||||
rotz = rot(0, -b, -a) @ comp
|
||||
c = torch.atan2(rotz[1, 0], rotz[0, 0])
|
||||
return a, b, c
|
||||
|
||||
|
||||
def kron(x, y):
|
||||
assert x.ndimension() == 2
|
||||
assert y.ndimension() == 2
|
||||
return torch.einsum("ij,kl->ikjl", (x, y)).view(x.size(0) * y.size(0), x.size(1) * y.size(1))
|
||||
|
||||
|
||||
################################################################################
|
||||
# Change of basis
|
||||
################################################################################
|
||||
|
||||
|
||||
def xyz_vector_basis_to_spherical_basis():
|
||||
"""
|
||||
to convert a vector [x, y, z] transforming with rot(a, b, c)
|
||||
into a vector transforming with irr_repr(1, a, b, c)
|
||||
see assert for usage
|
||||
"""
|
||||
with torch_default_dtype(torch.float64):
|
||||
A = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
|
||||
assert all(torch.allclose(irr_repr(1, a, b, c) @ A, A @ rot(a, b, c)) for a, b, c in torch.rand(10, 3))
|
||||
return A.type(torch.get_default_dtype())
|
||||
|
||||
|
||||
def tensor3x3_repr(a, b, c):
|
||||
"""
|
||||
representation of 3x3 tensors
|
||||
T --> R T R^t
|
||||
"""
|
||||
r = rot(a, b, c)
|
||||
return kron(r, r)
|
||||
|
||||
|
||||
def tensor3x3_repr_basis_to_spherical_basis():
|
||||
"""
|
||||
to convert a 3x3 tensor transforming with tensor3x3_repr(a, b, c)
|
||||
into its 1 + 3 + 5 component transforming with irr_repr(0, a, b, c), irr_repr(1, a, b, c), irr_repr(3, a, b, c)
|
||||
see assert for usage
|
||||
"""
|
||||
with torch_default_dtype(torch.float64):
|
||||
to1 = torch.tensor([
|
||||
[1, 0, 0, 0, 1, 0, 0, 0, 1],
|
||||
], dtype=torch.get_default_dtype())
|
||||
assert all(torch.allclose(irr_repr(0, a, b, c) @ to1, to1 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
|
||||
|
||||
to3 = torch.tensor([
|
||||
[0, 0, -1, 0, 0, 0, 1, 0, 0],
|
||||
[0, 1, 0, -1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0, -1, 0],
|
||||
], dtype=torch.get_default_dtype())
|
||||
assert all(torch.allclose(irr_repr(1, a, b, c) @ to3, to3 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
|
||||
|
||||
to5 = torch.tensor([
|
||||
[0, 1, 0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0, 1, 0],
|
||||
[-3**.5/3, 0, 0, 0, -3**.5/3, 0, 0, 0, 12**.5/3],
|
||||
[0, 0, 1, 0, 0, 0, 1, 0, 0],
|
||||
[1, 0, 0, 0, -1, 0, 0, 0, 0]
|
||||
], dtype=torch.get_default_dtype())
|
||||
assert all(torch.allclose(irr_repr(2, a, b, c) @ to5, to5 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
|
||||
|
||||
return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype())
|
||||
|
||||
|
||||
################################################################################
|
||||
# Tests
|
||||
################################################################################
|
||||
|
||||
|
||||
def test_is_representation(rep):
|
||||
"""
|
||||
rep(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = rep(Z(a1) Y(b1) Z(c1)) rep(Z(a2) Y(b2) Z(c2))
|
||||
"""
|
||||
with torch_default_dtype(torch.float64):
|
||||
a1, b1, c1, a2, b2, c2 = torch.rand(6)
|
||||
|
||||
r1 = rep(a1, b1, c1)
|
||||
r2 = rep(a2, b2, c2)
|
||||
|
||||
a, b, c = compose(a1, b1, c1, a2, b2, c2)
|
||||
r = rep(a, b, c)
|
||||
|
||||
r_ = r1 @ r2
|
||||
|
||||
d, r = (r - r_).abs().max(), r.abs().max()
|
||||
print(d.item(), r.item())
|
||||
assert d < 1e-10 * r, d / r
|
||||
|
||||
|
||||
def _test_spherical_harmonics(order):
|
||||
"""
|
||||
This test tests that
|
||||
- irr_repr
|
||||
- compose
|
||||
- spherical_harmonics
|
||||
are compatible
|
||||
|
||||
Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
|
||||
with x = Z(a) Y(b) eta
|
||||
"""
|
||||
with torch_default_dtype(torch.float64):
|
||||
a, b = torch.rand(2)
|
||||
alpha, beta, gamma = torch.rand(3)
|
||||
|
||||
ra, rb, _ = compose(alpha, beta, gamma, a, b, 0)
|
||||
Yrx = spherical_harmonics(order, ra, rb)
|
||||
|
||||
Y = spherical_harmonics(order, a, b)
|
||||
DrY = irr_repr(order, alpha, beta, gamma) @ Y
|
||||
|
||||
d, r = (Yrx - DrY).abs().max(), Y.abs().max()
|
||||
print(d.item(), r.item())
|
||||
assert d < 1e-10 * r, d / r
|
||||
|
||||
|
||||
def _test_change_basis_wigner_to_rot():
|
||||
# from from_lielearn_SO3.wigner_d import wigner_D_matrix
|
||||
from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
|
||||
|
||||
with torch_default_dtype(torch.float64):
|
||||
A = torch.tensor([
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 0]
|
||||
], dtype=torch.float64)
|
||||
|
||||
a, b, c = torch.rand(3)
|
||||
|
||||
r1 = A.t() @ torch.tensor(wigner_D_matrix(1, a, b, c), dtype=torch.float64) @ A
|
||||
r2 = rot(a, b, c)
|
||||
|
||||
d = (r1 - r2).abs().max()
|
||||
print(d.item())
|
||||
assert d < 1e-10
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from functools import partial
|
||||
|
||||
print("Change of basis")
|
||||
xyz_vector_basis_to_spherical_basis()
|
||||
test_is_representation(tensor3x3_repr)
|
||||
tensor3x3_repr_basis_to_spherical_basis()
|
||||
|
||||
print("Change of basis Wigner <-> rot")
|
||||
_test_change_basis_wigner_to_rot()
|
||||
_test_change_basis_wigner_to_rot()
|
||||
_test_change_basis_wigner_to_rot()
|
||||
|
||||
print("Spherical harmonics are solution of Y(rx) = D(r) Y(x)")
|
||||
for l in range(7):
|
||||
_test_spherical_harmonics(l)
|
||||
|
||||
print("Irreducible repr are indeed representations")
|
||||
for l in range(7):
|
||||
test_is_representation(partial(irr_repr, l))
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1 @@
|
|||
3239778
|
|
@ -0,0 +1,112 @@
|
|||
'''
|
||||
Cache in files
|
||||
'''
|
||||
from functools import wraps, lru_cache
|
||||
import pickle
|
||||
import gzip
|
||||
import os
|
||||
import sys
|
||||
import fcntl
|
||||
|
||||
|
||||
class FileSystemMutex:
|
||||
'''
|
||||
Mutual exclusion of different **processes** using the file system
|
||||
'''
|
||||
|
||||
def __init__(self, filename):
|
||||
self.handle = None
|
||||
self.filename = filename
|
||||
|
||||
def acquire(self):
|
||||
'''
|
||||
Locks the mutex
|
||||
if it is already locked, it waits (blocking function)
|
||||
'''
|
||||
self.handle = open(self.filename, 'w')
|
||||
fcntl.lockf(self.handle, fcntl.LOCK_EX)
|
||||
self.handle.write("{}\n".format(os.getpid()))
|
||||
self.handle.flush()
|
||||
|
||||
def release(self):
|
||||
'''
|
||||
Unlock the mutex
|
||||
'''
|
||||
if self.handle is None:
|
||||
raise RuntimeError()
|
||||
fcntl.lockf(self.handle, fcntl.LOCK_UN)
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
def __enter__(self):
|
||||
self.acquire()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.release()
|
||||
|
||||
|
||||
def cached_dirpklgz(dirname, maxsize=128):
|
||||
'''
|
||||
Cache a function with a directory
|
||||
|
||||
:param dirname: the directory path
|
||||
:param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
|
||||
'''
|
||||
|
||||
def decorator(func):
|
||||
'''
|
||||
The actual decorator
|
||||
'''
|
||||
|
||||
@lru_cache(maxsize=maxsize)
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
'''
|
||||
The wrapper of the function
|
||||
'''
|
||||
try:
|
||||
os.makedirs(dirname)
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
indexfile = os.path.join(dirname, "index.pkl")
|
||||
mutexfile = "cache_mutex"
|
||||
#mutexfile = os.path.join(dirname, "mutex")
|
||||
|
||||
with FileSystemMutex(mutexfile):
|
||||
try:
|
||||
with open(indexfile, "rb") as file:
|
||||
index = pickle.load(file)
|
||||
except FileNotFoundError:
|
||||
index = {}
|
||||
|
||||
key = (args, frozenset(kwargs), func.__defaults__)
|
||||
|
||||
try:
|
||||
filename = index[key]
|
||||
except KeyError:
|
||||
index[key] = filename = "{}.pkl.gz".format(len(index))
|
||||
with open(indexfile, "wb") as file:
|
||||
pickle.dump(index, file)
|
||||
|
||||
filepath = os.path.join(dirname, filename)
|
||||
|
||||
try:
|
||||
with FileSystemMutex(mutexfile):
|
||||
with gzip.open(filepath, "rb") as file:
|
||||
result = pickle.load(file)
|
||||
except FileNotFoundError:
|
||||
print("compute {}... ".format(filename), end="")
|
||||
sys.stdout.flush()
|
||||
result = func(*args, **kwargs)
|
||||
print("save {}... ".format(filename), end="")
|
||||
sys.stdout.flush()
|
||||
with FileSystemMutex(mutexfile):
|
||||
with gzip.open(filepath, "wb") as file:
|
||||
pickle.dump(result, file)
|
||||
print("done")
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -0,0 +1,25 @@
|
|||
the code in this folder was mostly obtained from https://github.com/mariogeiger/se3cnn/
|
||||
|
||||
which has the following license:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Mario Geiger
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,249 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.special import lpmv as lpmv_scipy
|
||||
|
||||
|
||||
def semifactorial(x):
|
||||
"""Compute the semifactorial function x!!.
|
||||
|
||||
x!! = x * (x-2) * (x-4) *...
|
||||
|
||||
Args:
|
||||
x: positive int
|
||||
Returns:
|
||||
float for x!!
|
||||
"""
|
||||
y = 1.
|
||||
for n in range(x, 1, -2):
|
||||
y *= n
|
||||
return y
|
||||
|
||||
|
||||
def pochhammer(x, k):
|
||||
"""Compute the pochhammer symbol (x)_k.
|
||||
|
||||
(x)_k = x * (x+1) * (x+2) *...* (x+k-1)
|
||||
|
||||
Args:
|
||||
x: positive int
|
||||
Returns:
|
||||
float for (x)_k
|
||||
"""
|
||||
xf = float(x)
|
||||
for n in range(x+1, x+k):
|
||||
xf *= n
|
||||
return xf
|
||||
|
||||
def lpmv(l, m, x):
|
||||
"""Associated Legendre function including Condon-Shortley phase.
|
||||
|
||||
Args:
|
||||
m: int order
|
||||
l: int degree
|
||||
x: float argument tensor
|
||||
Returns:
|
||||
tensor of x-shape
|
||||
"""
|
||||
m_abs = abs(m)
|
||||
if m_abs > l:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
# Compute P_m^m
|
||||
yold = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2)
|
||||
|
||||
# Compute P_{m+1}^m
|
||||
if m_abs != l:
|
||||
y = x * (2*m_abs+1) * yold
|
||||
else:
|
||||
y = yold
|
||||
|
||||
# Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
|
||||
for i in range(m_abs+2, l+1):
|
||||
tmp = y
|
||||
# Inplace speedup
|
||||
y = ((2*i-1) / (i-m_abs)) * x * y
|
||||
y -= ((i+m_abs-1)/(i-m_abs)) * yold
|
||||
yold = tmp
|
||||
|
||||
if m < 0:
|
||||
y *= ((-1)**m / pochhammer(l+m+1, -2*m))
|
||||
|
||||
return y
|
||||
|
||||
def tesseral_harmonics(l, m, theta=0., phi=0.):
|
||||
"""Tesseral spherical harmonic with Condon-Shortley phase.
|
||||
|
||||
The Tesseral spherical harmonics are also known as the real spherical
|
||||
harmonics.
|
||||
|
||||
Args:
|
||||
l: int for degree
|
||||
m: int for order, where -l <= m < l
|
||||
theta: collatitude or polar angle
|
||||
phi: longitude or azimuth
|
||||
Returns:
|
||||
tensor of shape theta
|
||||
"""
|
||||
assert abs(m) <= l, "absolute value of order m must be <= degree l"
|
||||
|
||||
N = np.sqrt((2*l+1) / (4*np.pi))
|
||||
leg = lpmv(l, abs(m), torch.cos(theta))
|
||||
if m == 0:
|
||||
return N*leg
|
||||
elif m > 0:
|
||||
Y = torch.cos(m*phi) * leg
|
||||
else:
|
||||
Y = torch.sin(abs(m)*phi) * leg
|
||||
N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
|
||||
Y *= N
|
||||
return Y
|
||||
|
||||
class SphericalHarmonics(object):
|
||||
def __init__(self):
|
||||
self.leg = {}
|
||||
|
||||
def clear(self):
|
||||
self.leg = {}
|
||||
|
||||
def negative_lpmv(self, l, m, y):
|
||||
"""Compute negative order coefficients"""
|
||||
if m < 0:
|
||||
y *= ((-1)**m / pochhammer(l+m+1, -2*m))
|
||||
return y
|
||||
|
||||
def lpmv(self, l, m, x):
|
||||
"""Associated Legendre function including Condon-Shortley phase.
|
||||
|
||||
Args:
|
||||
m: int order
|
||||
l: int degree
|
||||
x: float argument tensor
|
||||
Returns:
|
||||
tensor of x-shape
|
||||
"""
|
||||
# Check memoized versions
|
||||
m_abs = abs(m)
|
||||
if (l,m) in self.leg:
|
||||
return self.leg[(l,m)]
|
||||
elif m_abs > l:
|
||||
return None
|
||||
elif l == 0:
|
||||
self.leg[(l,m)] = torch.ones_like(x)
|
||||
return self.leg[(l,m)]
|
||||
|
||||
# Check if on boundary else recurse solution down to boundary
|
||||
if m_abs == l:
|
||||
# Compute P_m^m
|
||||
y = (-1)**m_abs * semifactorial(2*m_abs-1)
|
||||
y *= torch.pow(1-x*x, m_abs/2)
|
||||
self.leg[(l,m)] = self.negative_lpmv(l, m, y)
|
||||
return self.leg[(l,m)]
|
||||
else:
|
||||
# Recursively precompute lower degree harmonics
|
||||
self.lpmv(l-1, m, x)
|
||||
|
||||
# Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
|
||||
# Inplace speedup
|
||||
y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
|
||||
if l - m_abs > 1:
|
||||
y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
|
||||
#self.leg[(l, m_abs)] = y
|
||||
|
||||
if m < 0:
|
||||
y = self.negative_lpmv(l, m, y)
|
||||
self.leg[(l,m)] = y
|
||||
|
||||
return self.leg[(l,m)]
|
||||
|
||||
def get_element(self, l, m, theta, phi):
|
||||
"""Tesseral spherical harmonic with Condon-Shortley phase.
|
||||
|
||||
The Tesseral spherical harmonics are also known as the real spherical
|
||||
harmonics.
|
||||
|
||||
Args:
|
||||
l: int for degree
|
||||
m: int for order, where -l <= m < l
|
||||
theta: collatitude or polar angle
|
||||
phi: longitude or azimuth
|
||||
Returns:
|
||||
tensor of shape theta
|
||||
"""
|
||||
assert abs(m) <= l, "absolute value of order m must be <= degree l"
|
||||
|
||||
N = np.sqrt((2*l+1) / (4*np.pi))
|
||||
leg = self.lpmv(l, abs(m), torch.cos(theta))
|
||||
if m == 0:
|
||||
return N*leg
|
||||
elif m > 0:
|
||||
Y = torch.cos(m*phi) * leg
|
||||
else:
|
||||
Y = torch.sin(abs(m)*phi) * leg
|
||||
N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
|
||||
Y *= N
|
||||
return Y
|
||||
|
||||
def get(self, l, theta, phi, refresh=True):
|
||||
"""Tesseral harmonic with Condon-Shortley phase.
|
||||
|
||||
The Tesseral spherical harmonics are also known as the real spherical
|
||||
harmonics.
|
||||
|
||||
Args:
|
||||
l: int for degree
|
||||
theta: collatitude or polar angle
|
||||
phi: longitude or azimuth
|
||||
Returns:
|
||||
tensor of shape [*theta.shape, 2*l+1]
|
||||
"""
|
||||
results = []
|
||||
if refresh:
|
||||
self.clear()
|
||||
for m in range(-l, l+1):
|
||||
results.append(self.get_element(l, m, theta, phi))
|
||||
return torch.stack(results, -1)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lie_learn.representations.SO3.spherical_harmonics import sh
|
||||
device = 'cuda'
|
||||
dtype = torch.float64
|
||||
bs = 32
|
||||
theta = 0.1*torch.randn(bs,1024,10, dtype=dtype)
|
||||
phi = 0.1*torch.randn(bs,1024,10, dtype=dtype)
|
||||
cu_theta = theta.to(device)
|
||||
cu_phi = phi.to(device)
|
||||
s0 = s1 = s2 = 0
|
||||
max_error = -1.
|
||||
|
||||
sph_har = SphericalHarmonics()
|
||||
for l in range(10):
|
||||
for m in range(l, -l-1, -1):
|
||||
start = time.time()
|
||||
#y = tesseral_harmonics(l, m, theta, phi)
|
||||
y = sph_har.get_element(l, m, cu_theta, cu_phi).type(torch.float32)
|
||||
#y = sph_har.lpmv(l, m, phi)
|
||||
s0 += time.time() - start
|
||||
start = time.time()
|
||||
z = sh(l, m, theta, phi)
|
||||
#z = lpmv_scipy(m, l, phi).numpy()
|
||||
s1 += time.time() - start
|
||||
|
||||
error = np.mean(np.abs((y.cpu().numpy() - z) / z))
|
||||
max_error = max(max_error, error)
|
||||
print(f"l: {l}, m: {m} ", error)
|
||||
|
||||
#start = time.time()
|
||||
#sph_har.get(l, theta, phi)
|
||||
#s2 += time.time() - start
|
||||
|
||||
print('#################')
|
||||
|
||||
print(f"Max error: {max_error}")
|
||||
print(f"Time diff: {s0/s1}")
|
||||
print(f"Total time: {s0}")
|
||||
#print(f"Time diff: {s2/s1}")
|
|
@ -0,0 +1,326 @@
|
|||
import os
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
from equivariant_attention.from_se3cnn.SO3 import irr_repr, torch_default_dtype
|
||||
from equivariant_attention.from_se3cnn.cache_file import cached_dirpklgz
|
||||
from equivariant_attention.from_se3cnn.representations import SphericalHarmonics
|
||||
|
||||
################################################################################
|
||||
# Solving the constraint coming from the stabilizer of 0 and e
|
||||
################################################################################
|
||||
|
||||
def get_matrix_kernel(A, eps=1e-10):
|
||||
'''
|
||||
Compute an orthonormal basis of the kernel (x_1, x_2, ...)
|
||||
A x_i = 0
|
||||
scalar_product(x_i, x_j) = delta_ij
|
||||
|
||||
:param A: matrix
|
||||
:return: matrix where each row is a basis vector of the kernel of A
|
||||
'''
|
||||
_u, s, v = torch.svd(A)
|
||||
|
||||
# A = u @ torch.diag(s) @ v.t()
|
||||
kernel = v.t()[s < eps]
|
||||
return kernel
|
||||
|
||||
|
||||
def get_matrices_kernel(As, eps=1e-10):
|
||||
'''
|
||||
Computes the commun kernel of all the As matrices
|
||||
'''
|
||||
return get_matrix_kernel(torch.cat(As, dim=0), eps)
|
||||
|
||||
|
||||
@cached_dirpklgz("%s/cache/trans_Q"%os.path.dirname(os.path.realpath(__file__)))
|
||||
def _basis_transformation_Q_J(J, order_in, order_out, version=3): # pylint: disable=W0613
|
||||
"""
|
||||
:param J: order of the spherical harmonics
|
||||
:param order_in: order of the input representation
|
||||
:param order_out: order of the output representation
|
||||
:return: one part of the Q^-1 matrix of the article
|
||||
"""
|
||||
with torch_default_dtype(torch.float64):
|
||||
def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
|
||||
|
||||
def _sylvester_submatrix(J, a, b, c):
|
||||
''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
|
||||
R_tensor = _R_tensor(a, b, c) # [m_out * m_in, m_out * m_in]
|
||||
R_irrep_J = irr_repr(J, a, b, c) # [m, m]
|
||||
return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
|
||||
kron(torch.eye(R_tensor.size(0)), R_irrep_J.t()) # [(m_out * m_in) * m, (m_out * m_in) * m]
|
||||
|
||||
random_angles = [
|
||||
[4.41301023, 5.56684102, 4.59384642],
|
||||
[4.93325116, 6.12697327, 4.14574096],
|
||||
[0.53878964, 4.09050444, 5.36539036],
|
||||
[2.16017393, 3.48835314, 5.55174441],
|
||||
[2.52385107, 0.2908958, 3.90040975]
|
||||
]
|
||||
null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
|
||||
assert null_space.size(0) == 1, null_space.size() # unique subspace solution
|
||||
Q_J = null_space[0] # [(m_out * m_in) * m]
|
||||
Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1) # [m_out * m_in, m]
|
||||
assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))
|
||||
|
||||
assert Q_J.dtype == torch.float64
|
||||
return Q_J # [m_out * m_in, m]
|
||||
|
||||
|
||||
def get_spherical_from_cartesian_torch(cartesian, divide_radius_by=1.0):
|
||||
|
||||
###################################################################################################################
|
||||
# ON ANGLE CONVENTION
|
||||
#
|
||||
# sh has following convention for angles:
|
||||
# :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
|
||||
# :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
|
||||
#
|
||||
# the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
|
||||
# beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
|
||||
# alpha = phi
|
||||
#
|
||||
###################################################################################################################
|
||||
|
||||
# initialise return array
|
||||
# ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
|
||||
spherical = torch.zeros_like(cartesian)
|
||||
|
||||
# indices for return array
|
||||
ind_radius = 0
|
||||
ind_alpha = 1
|
||||
ind_beta = 2
|
||||
|
||||
cartesian_x = 2
|
||||
cartesian_y = 0
|
||||
cartesian_z = 1
|
||||
|
||||
# get projected radius in xy plane
|
||||
# xy = xyz[:,0]**2 + xyz[:,1]**2
|
||||
r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2
|
||||
|
||||
# get second angle
|
||||
# version 'elevation angle defined from Z-axis down'
|
||||
spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])
|
||||
# ptsnew[:,4] = np.arctan2(np.sqrt(xy), xyz[:,2])
|
||||
# version 'elevation angle defined from XY-plane up'
|
||||
#ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy))
|
||||
# spherical[:, ind_beta] = np.arctan2(cartesian[:, 2], np.sqrt(r_xy))
|
||||
|
||||
# get angle in x-y plane
|
||||
spherical[...,ind_alpha] = torch.atan2(cartesian[...,cartesian_y], cartesian[...,cartesian_x])
|
||||
|
||||
# get overall radius
|
||||
# ptsnew[:,3] = np.sqrt(xy + xyz[:,2]**2)
|
||||
if divide_radius_by == 1.0:
|
||||
spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[...,cartesian_z]**2)
|
||||
else:
|
||||
spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[...,cartesian_z]**2)/divide_radius_by
|
||||
|
||||
return spherical
|
||||
|
||||
|
||||
def get_spherical_from_cartesian(cartesian):
|
||||
|
||||
###################################################################################################################
|
||||
# ON ANGLE CONVENTION
|
||||
#
|
||||
# sh has following convention for angles:
|
||||
# :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
|
||||
# :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
|
||||
#
|
||||
# the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
|
||||
# beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
|
||||
# alpha = phi
|
||||
#
|
||||
###################################################################################################################
|
||||
|
||||
if torch.is_tensor(cartesian):
|
||||
cartesian = np.array(cartesian.cpu())
|
||||
|
||||
# initialise return array
|
||||
# ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
|
||||
spherical = np.zeros(cartesian.shape)
|
||||
|
||||
# indices for return array
|
||||
ind_radius = 0
|
||||
ind_alpha = 1
|
||||
ind_beta = 2
|
||||
|
||||
cartesian_x = 2
|
||||
cartesian_y = 0
|
||||
cartesian_z = 1
|
||||
|
||||
# get projected radius in xy plane
|
||||
# xy = xyz[:,0]**2 + xyz[:,1]**2
|
||||
r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2
|
||||
|
||||
# get overall radius
|
||||
# ptsnew[:,3] = np.sqrt(xy + xyz[:,2]**2)
|
||||
spherical[..., ind_radius] = np.sqrt(r_xy + cartesian[...,cartesian_z]**2)
|
||||
|
||||
# get second angle
|
||||
# version 'elevation angle defined from Z-axis down'
|
||||
spherical[..., ind_beta] = np.arctan2(np.sqrt(r_xy), cartesian[..., cartesian_z])
|
||||
# ptsnew[:,4] = np.arctan2(np.sqrt(xy), xyz[:,2])
|
||||
# version 'elevation angle defined from XY-plane up'
|
||||
#ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy))
|
||||
# spherical[:, ind_beta] = np.arctan2(cartesian[:, 2], np.sqrt(r_xy))
|
||||
|
||||
# get angle in x-y plane
|
||||
spherical[...,ind_alpha] = np.arctan2(cartesian[...,cartesian_y], cartesian[...,cartesian_x])
|
||||
|
||||
return spherical
|
||||
|
||||
def test_coordinate_conversion():
|
||||
p = np.array([0, 0, -1])
|
||||
expected = np.array([1, 0, 0])
|
||||
assert get_spherical_from_cartesian(p) == expected
|
||||
return True
|
||||
|
||||
|
||||
def spherical_harmonics(order, alpha, beta, dtype=None):
|
||||
"""
|
||||
spherical harmonics
|
||||
- compatible with irr_repr and compose
|
||||
|
||||
computation time: excecuting 1000 times with array length 1 took 0.29 seconds;
|
||||
executing it once with array of length 1000 took 0.0022 seconds
|
||||
"""
|
||||
#Y = [tesseral_harmonics(order, m, theta=math.pi - beta, phi=alpha) for m in range(-order, order + 1)]
|
||||
#Y = torch.stack(Y, -1)
|
||||
# Y should have dimension 2*order + 1
|
||||
return SphericalHarmonics.get(order, theta=math.pi-beta, phi=alpha)
|
||||
|
||||
def kron(a, b):
|
||||
"""
|
||||
A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk
|
||||
|
||||
Kronecker product of matrices a and b with leading batch dimensions.
|
||||
Batch dimensions are broadcast. The number of them mush
|
||||
:type a: torch.Tensor
|
||||
:type b: torch.Tensor
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
|
||||
res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
|
||||
siz0 = res.shape[:-4]
|
||||
return res.reshape(siz0 + siz1)
|
||||
|
||||
|
||||
def get_maximum_order_unary_only(per_layer_orders_and_multiplicities):
|
||||
"""
|
||||
determine what spherical harmonics we need to pre-compute. if we have the
|
||||
unary term only, we need to compare all adjacent layers
|
||||
|
||||
the spherical harmonics function depends on J (irrep order) purely, which is dedfined by
|
||||
order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
|
||||
simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
|
||||
order spherical harmonics which we won't actually need)
|
||||
|
||||
:param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
|
||||
:return: integer indicating maximum order J
|
||||
"""
|
||||
|
||||
n_layers = len(per_layer_orders_and_multiplicities)
|
||||
|
||||
# extract orders only
|
||||
per_layer_orders = []
|
||||
for i in range(n_layers):
|
||||
cur = per_layer_orders_and_multiplicities[i]
|
||||
cur = [o for (m, o) in cur]
|
||||
per_layer_orders.append(cur)
|
||||
|
||||
track_max = 0
|
||||
# compare two (adjacent) layers at a time
|
||||
for i in range(n_layers - 1):
|
||||
cur = per_layer_orders[i]
|
||||
nex = per_layer_orders[i + 1]
|
||||
track_max = max(max(cur) + max(nex), track_max)
|
||||
|
||||
return track_max
|
||||
|
||||
|
||||
def get_maximum_order_with_pairwise(per_layer_orders_and_multiplicities):
|
||||
"""
|
||||
determine what spherical harmonics we need to pre-compute. for pairwise
|
||||
interactions, this will just be twice the maximum order
|
||||
|
||||
the spherical harmonics function depends on J (irrep order) purely, which is defined by
|
||||
order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
|
||||
simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
|
||||
order spherical harmonics which we won't actually need)
|
||||
|
||||
:param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
|
||||
:return: integer indicating maximum order J
|
||||
"""
|
||||
|
||||
n_layers = len(per_layer_orders_and_multiplicities)
|
||||
|
||||
track_max = 0
|
||||
for i in range(n_layers):
|
||||
cur = per_layer_orders_and_multiplicities[i]
|
||||
# extract orders only
|
||||
orders = [o for (m, o) in cur]
|
||||
track_max = max(track_max, max(orders))
|
||||
|
||||
return 2*track_max
|
||||
|
||||
|
||||
def precompute_sh(r_ij, max_J):
|
||||
"""
|
||||
pre-comput spherical harmonics up to order max_J
|
||||
|
||||
:param r_ij: relative positions
|
||||
:param max_J: maximum order used in entire network
|
||||
:return: dict where each entry has shape [B,N,K,2J+1]
|
||||
"""
|
||||
|
||||
i_distance = 0
|
||||
i_alpha = 1
|
||||
i_beta = 2
|
||||
|
||||
Y_Js = {}
|
||||
sh = SphericalHarmonics()
|
||||
|
||||
for J in range(max_J+1):
|
||||
# dimension [B,N,K,2J+1]
|
||||
#Y_Js[J] = spherical_harmonics(order=J, alpha=r_ij[...,i_alpha], beta=r_ij[...,i_beta])
|
||||
Y_Js[J] = sh.get(J, theta=math.pi-r_ij[...,i_beta], phi=r_ij[...,i_alpha], refresh=False)
|
||||
|
||||
sh.clear()
|
||||
return Y_Js
|
||||
|
||||
|
||||
class ScalarActivation3rdDim(torch.nn.Module):
|
||||
def __init__(self, n_dim, activation, bias=True):
|
||||
'''
|
||||
Can be used only with scalar fields [B, N, s] on last dimension
|
||||
|
||||
:param n_dim: number of scalar fields to apply activation to
|
||||
:param bool bias: add a bias before the applying the activation
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.activation = activation
|
||||
|
||||
if bias and n_dim > 0:
|
||||
self.bias = torch.nn.Parameter(torch.zeros(n_dim))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
'''
|
||||
:param input: [B, N, s]
|
||||
'''
|
||||
|
||||
assert len(np.array(input.shape)) == 3
|
||||
|
||||
if self.bias is not None:
|
||||
x = input + self.bias.view(1, 1, -1)
|
||||
else:
|
||||
x = input
|
||||
x = self.activation(x)
|
||||
|
||||
return x
|
905
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/equivariant_attention/modules.py
Executable file
905
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/equivariant_attention/modules.py
Executable file
|
@ -0,0 +1,905 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from equivariant_attention.from_se3cnn import utils_steerable
|
||||
from equivariant_attention.fibers import Fiber, fiber2head
|
||||
from utils.utils_logging import log_gradient_norm
|
||||
|
||||
import dgl
|
||||
import dgl.function as fn
|
||||
from dgl.nn.pytorch.softmax import edge_softmax
|
||||
from dgl.nn.pytorch.glob import AvgPooling, MaxPooling
|
||||
|
||||
from packaging import version
|
||||
|
||||
|
||||
def get_basis(G, max_degree, compute_gradients):
|
||||
"""Precompute the SE(3)-equivariant weight basis, W_J^lk(x)
|
||||
|
||||
This is called by get_basis_and_r().
|
||||
|
||||
Args:
|
||||
G: DGL graph instance of type dgl.DGLGraph
|
||||
max_degree: non-negative int for degree of highest feature type
|
||||
compute_gradients: boolean, whether to compute gradients during basis construction
|
||||
Returns:
|
||||
dict of equivariant bases. Keys are in the form 'd_in,d_out'. Values are
|
||||
tensors of shape (batch_size, 1, 2*d_out+1, 1, 2*d_in+1, number_of_bases)
|
||||
where the 1's will later be broadcast to the number of output and input
|
||||
channels
|
||||
"""
|
||||
if compute_gradients:
|
||||
context = nullcontext()
|
||||
else:
|
||||
context = torch.no_grad()
|
||||
|
||||
with context:
|
||||
cloned_d = torch.clone(G.edata['d'])
|
||||
|
||||
if G.edata['d'].requires_grad:
|
||||
cloned_d.requires_grad_()
|
||||
log_gradient_norm(cloned_d, 'Basis computation flow')
|
||||
|
||||
# Relative positional encodings (vector)
|
||||
r_ij = utils_steerable.get_spherical_from_cartesian_torch(cloned_d)
|
||||
# Spherical harmonic basis
|
||||
Y = utils_steerable.precompute_sh(r_ij, 2*max_degree)
|
||||
device = Y[0].device
|
||||
|
||||
basis = {}
|
||||
for d_in in range(max_degree+1):
|
||||
for d_out in range(max_degree+1):
|
||||
K_Js = []
|
||||
for J in range(abs(d_in-d_out), d_in+d_out+1):
|
||||
# Get spherical harmonic projection matrices
|
||||
Q_J = utils_steerable._basis_transformation_Q_J(J, d_in, d_out)
|
||||
Q_J = Q_J.float().T.to(device)
|
||||
|
||||
# Create kernel from spherical harmonics
|
||||
K_J = torch.matmul(Y[J], Q_J)
|
||||
K_Js.append(K_J)
|
||||
|
||||
# Reshape so can take linear combinations with a dot product
|
||||
size = (-1, 1, 2*d_out+1, 1, 2*d_in+1, 2*min(d_in, d_out)+1)
|
||||
basis[f'{d_in},{d_out}'] = torch.stack(K_Js, -1).view(*size)
|
||||
return basis
|
||||
|
||||
|
||||
def get_r(G):
|
||||
"""Compute internodal distances"""
|
||||
cloned_d = torch.clone(G.edata['d'])
|
||||
|
||||
if G.edata['d'].requires_grad:
|
||||
cloned_d.requires_grad_()
|
||||
log_gradient_norm(cloned_d, 'Neural networks flow')
|
||||
|
||||
return torch.sqrt(torch.sum(cloned_d**2, -1, keepdim=True))
|
||||
|
||||
|
||||
def get_basis_and_r(G, max_degree, compute_gradients=False):
|
||||
"""Return equivariant weight basis (basis) and internodal distances (r).
|
||||
|
||||
Call this function *once* at the start of each forward pass of the model.
|
||||
It computes the equivariant weight basis, W_J^lk(x), and internodal
|
||||
distances, needed to compute varphi_J^lk(x), of eqn 8 of
|
||||
https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
|
||||
can be shared as input across all SE(3)-Transformer layers in a model.
|
||||
|
||||
Args:
|
||||
G: DGL graph instance of type dgl.DGLGraph()
|
||||
max_degree: non-negative int for degree of highest feature-type
|
||||
compute_gradients: controls whether to compute gradients during basis construction
|
||||
Returns:
|
||||
dict of equivariant bases, keys are in form '<d_in><d_out>'
|
||||
vector of relative distances, ordered according to edge ordering of G
|
||||
"""
|
||||
basis = get_basis(G, max_degree, compute_gradients)
|
||||
r = get_r(G)
|
||||
return basis, r
|
||||
|
||||
|
||||
### SE(3) equivariant operations on graphs in DGL
|
||||
|
||||
class GConvSE3(nn.Module):
|
||||
"""A tensor field network layer as a DGL module.
|
||||
|
||||
GConvSE3 stands for a Graph Convolution SE(3)-equivariant layer. It is the
|
||||
equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
|
||||
conv layer in a GCN.
|
||||
|
||||
At each node, the activations are split into different "feature types",
|
||||
indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..
|
||||
"""
|
||||
def __init__(self, f_in, f_out, self_interaction: bool=False, edge_dim: int=0, flavor='skip'):
|
||||
"""SE(3)-equivariant Graph Conv Layer
|
||||
|
||||
Args:
|
||||
f_in: list of tuples [(multiplicities, type),...]
|
||||
f_out: list of tuples [(multiplicities, type),...]
|
||||
self_interaction: include self-interaction in convolution
|
||||
edge_dim: number of dimensions for edge embedding
|
||||
flavor: allows ['TFN', 'skip'], where 'skip' adds a skip connection
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_in = f_in
|
||||
self.f_out = f_out
|
||||
self.edge_dim = edge_dim
|
||||
self.self_interaction = self_interaction
|
||||
self.flavor = flavor
|
||||
|
||||
# Neighbor -> center weights
|
||||
self.kernel_unary = nn.ModuleDict()
|
||||
for (mi, di) in self.f_in.structure:
|
||||
for (mo, do) in self.f_out.structure:
|
||||
self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)
|
||||
|
||||
# Center -> center weights
|
||||
self.kernel_self = nn.ParameterDict()
|
||||
if self_interaction:
|
||||
assert self.flavor in ['TFN', 'skip']
|
||||
if self.flavor == 'TFN':
|
||||
for m_out, d_out in self.f_out.structure:
|
||||
W = nn.Parameter(torch.randn(1, m_out, m_out) / np.sqrt(m_out))
|
||||
self.kernel_self[f'{d_out}'] = W
|
||||
elif self.flavor == 'skip':
|
||||
for m_in, d_in in self.f_in.structure:
|
||||
if d_in in self.f_out.degrees:
|
||||
m_out = self.f_out.structure_dict[d_in]
|
||||
W = nn.Parameter(torch.randn(1, m_out, m_in) / np.sqrt(m_in))
|
||||
self.kernel_self[f'{d_in}'] = W
|
||||
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'GConvSE3(structure={self.f_out}, self_interaction={self.self_interaction})'
|
||||
|
||||
|
||||
def udf_u_mul_e(self, d_out):
|
||||
"""Compute the convolution for a single output feature type.
|
||||
|
||||
This function is set up as a User Defined Function in DGL.
|
||||
|
||||
Args:
|
||||
d_out: output feature type
|
||||
Returns:
|
||||
edge -> node function handle
|
||||
"""
|
||||
def fnc(edges):
|
||||
# Neighbor -> center messages
|
||||
msg = 0
|
||||
for m_in, d_in in self.f_in.structure:
|
||||
src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
|
||||
edge = edges.data[f'({d_in},{d_out})']
|
||||
msg = msg + torch.matmul(edge, src)
|
||||
msg = msg.view(msg.shape[0], -1, 2*d_out+1)
|
||||
|
||||
# Center -> center messages
|
||||
if self.self_interaction:
|
||||
if f'{d_out}' in self.kernel_self.keys():
|
||||
if self.flavor == 'TFN':
|
||||
W = self.kernel_self[f'{d_out}']
|
||||
msg = torch.matmul(W, msg)
|
||||
if self.flavor == 'skip':
|
||||
dst = edges.dst[f'{d_out}']
|
||||
W = self.kernel_self[f'{d_out}']
|
||||
msg = msg + torch.matmul(W, dst)
|
||||
|
||||
return {'msg': msg.view(msg.shape[0], -1, 2*d_out+1)}
|
||||
return fnc
|
||||
|
||||
def forward(self, h, G=None, r=None, basis=None, **kwargs):
|
||||
"""Forward pass of the linear layer
|
||||
|
||||
Args:
|
||||
G: minibatch of (homo)graphs
|
||||
h: dict of features
|
||||
r: inter-atomic distances
|
||||
basis: pre-computed Q * Y
|
||||
Returns:
|
||||
tensor with new features [B, n_points, n_features_out]
|
||||
"""
|
||||
with G.local_scope():
|
||||
# Add node features to local graph scope
|
||||
for k, v in h.items():
|
||||
G.ndata[k] = v
|
||||
|
||||
# Add edge features
|
||||
if 'w' in G.edata.keys():
|
||||
w = G.edata['w']
|
||||
feat = torch.cat([w, r], -1)
|
||||
else:
|
||||
feat = torch.cat([r, ], -1)
|
||||
|
||||
for (mi, di) in self.f_in.structure:
|
||||
for (mo, do) in self.f_out.structure:
|
||||
etype = f'({di},{do})'
|
||||
G.edata[etype] = self.kernel_unary[etype](feat, basis)
|
||||
|
||||
# Perform message-passing for each output feature type
|
||||
for d in self.f_out.degrees:
|
||||
G.update_all(self.udf_u_mul_e(d), fn.mean('msg', f'out{d}'))
|
||||
|
||||
return {f'{d}': G.ndata[f'out{d}'] for d in self.f_out.degrees}
|
||||
|
||||
|
||||
class RadialFunc(nn.Module):
|
||||
"""NN parameterized radial profile function."""
|
||||
def __init__(self, num_freq, in_dim, out_dim, edge_dim: int=0):
|
||||
"""NN parameterized radial profile function.
|
||||
|
||||
Args:
|
||||
num_freq: number of output frequencies
|
||||
in_dim: multiplicity of input (num input channels)
|
||||
out_dim: multiplicity of output (num output channels)
|
||||
edge_dim: number of dimensions for edge embedding
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_freq = num_freq
|
||||
self.in_dim = in_dim
|
||||
self.mid_dim = 32
|
||||
self.out_dim = out_dim
|
||||
self.edge_dim = edge_dim
|
||||
|
||||
self.net = nn.Sequential(nn.Linear(self.edge_dim+1,self.mid_dim),
|
||||
BN(self.mid_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.mid_dim,self.mid_dim),
|
||||
BN(self.mid_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.mid_dim,self.num_freq*in_dim*out_dim))
|
||||
|
||||
nn.init.kaiming_uniform_(self.net[0].weight)
|
||||
nn.init.kaiming_uniform_(self.net[3].weight)
|
||||
nn.init.kaiming_uniform_(self.net[6].weight)
|
||||
|
||||
def __repr__(self):
|
||||
return f"RadialFunc(edge_dim={self.edge_dim}, in_dim={self.in_dim}, out_dim={self.out_dim})"
|
||||
|
||||
def forward(self, x):
|
||||
y = self.net(x)
|
||||
return y.view(-1, self.out_dim, 1, self.in_dim, 1, self.num_freq)
|
||||
|
||||
|
||||
class PairwiseConv(nn.Module):
|
||||
"""SE(3)-equivariant convolution between two single-type features"""
|
||||
def __init__(self, degree_in: int, nc_in: int, degree_out: int,
|
||||
nc_out: int, edge_dim: int=0):
|
||||
"""SE(3)-equivariant convolution between a pair of feature types.
|
||||
|
||||
This layer performs a convolution from nc_in features of type degree_in
|
||||
to nc_out features of type degree_out.
|
||||
|
||||
Args:
|
||||
degree_in: degree of input fiber
|
||||
nc_in: number of channels on input
|
||||
degree_out: degree of out order
|
||||
nc_out: number of channels on output
|
||||
edge_dim: number of dimensions for edge embedding
|
||||
"""
|
||||
super().__init__()
|
||||
# Log settings
|
||||
self.degree_in = degree_in
|
||||
self.degree_out = degree_out
|
||||
self.nc_in = nc_in
|
||||
self.nc_out = nc_out
|
||||
|
||||
# Functions of the degree
|
||||
self.num_freq = 2*min(degree_in, degree_out) + 1
|
||||
self.d_out = 2*degree_out + 1
|
||||
self.edge_dim = edge_dim
|
||||
|
||||
# Radial profile function
|
||||
self.rp = RadialFunc(self.num_freq, nc_in, nc_out, self.edge_dim)
|
||||
|
||||
def forward(self, feat, basis):
|
||||
# Get radial weights
|
||||
R = self.rp(feat)
|
||||
kernel = torch.sum(R * basis[f'{self.degree_in},{self.degree_out}'], -1)
|
||||
return kernel.view(kernel.shape[0], self.d_out*self.nc_out, -1)
|
||||
|
||||
|
||||
class G1x1SE3(nn.Module):
|
||||
"""Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
||||
|
||||
This is equivalent to a self-interaction layer in TensorField Networks.
|
||||
"""
|
||||
def __init__(self, f_in, f_out, learnable=True):
|
||||
"""SE(3)-equivariant 1x1 convolution.
|
||||
|
||||
Args:
|
||||
f_in: input Fiber() of feature multiplicities and types
|
||||
f_out: output Fiber() of feature multiplicities and types
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_in = f_in
|
||||
self.f_out = f_out
|
||||
|
||||
# Linear mappings: 1 per output feature type
|
||||
self.transform = nn.ParameterDict()
|
||||
for m_out, d_out in self.f_out.structure:
|
||||
m_in = self.f_in.structure_dict[d_out]
|
||||
self.transform[str(d_out)] = nn.Parameter(torch.randn(m_out, m_in) / np.sqrt(m_in), requires_grad=learnable)
|
||||
|
||||
def __repr__(self):
|
||||
return f"G1x1SE3(structure={self.f_out})"
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
output = {}
|
||||
for k, v in features.items():
|
||||
if str(k) in self.transform.keys():
|
||||
output[k] = torch.matmul(self.transform[str(k)], v)
|
||||
return output
|
||||
|
||||
|
||||
class GNormBias(nn.Module):
|
||||
"""Norm-based SE(3)-equivariant nonlinearity with only learned biases."""
|
||||
|
||||
def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
|
||||
num_layers: int = 0):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
fiber: Fiber() of feature multiplicities and types
|
||||
nonlin: nonlinearity to use everywhere
|
||||
num_layers: non-negative number of linear layers in fnc
|
||||
"""
|
||||
super().__init__()
|
||||
self.fiber = fiber
|
||||
self.nonlin = nonlin
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Regularization for computing phase: gradients explode otherwise
|
||||
self.eps = 1e-12
|
||||
|
||||
# Norm mappings: 1 per feature type
|
||||
self.bias = nn.ParameterDict()
|
||||
for m, d in self.fiber.structure:
|
||||
self.bias[str(d)] = nn.Parameter(torch.randn(m).view(1, m))
|
||||
|
||||
def __repr__(self):
|
||||
return f"GNormTFN()"
|
||||
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
output = {}
|
||||
for k, v in features.items():
|
||||
# Compute the norms and normalized features
|
||||
# v shape: [...,m , 2*k+1]
|
||||
norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps).expand_as(v)
|
||||
phase = v / norm
|
||||
|
||||
# Transform on norms
|
||||
# transformed = self.transform[str(k)](norm[..., 0]).unsqueeze(-1)
|
||||
transformed = self.nonlin(norm[..., 0] + self.bias[str(k)])
|
||||
|
||||
# Nonlinearity on norm
|
||||
output[k] = (transformed.unsqueeze(-1) * phase).view(*v.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GAttentiveSelfInt(nn.Module):
|
||||
|
||||
def __init__(self, f_in, f_out):
|
||||
"""SE(3)-equivariant 1x1 convolution.
|
||||
|
||||
Args:
|
||||
f_in: input Fiber() of feature multiplicities and types
|
||||
f_out: output Fiber() of feature multiplicities and types
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_in = f_in
|
||||
self.f_out = f_out
|
||||
self.nonlin = nn.LeakyReLU()
|
||||
self.num_layers = 2
|
||||
self.eps = 1e-12 # regularisation for phase: gradients explode otherwise
|
||||
|
||||
# one network for attention weights per degree
|
||||
self.transform = nn.ModuleDict()
|
||||
for o, m_in in self.f_in.structure_dict.items():
|
||||
m_out = self.f_out.structure_dict[o]
|
||||
self.transform[str(o)] = self._build_net(m_in, m_out)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AttentiveSelfInteractionSE3(in={self.f_in}, out={self.f_out})"
|
||||
|
||||
def _build_net(self, m_in: int, m_out):
|
||||
n_hidden = m_in * m_out
|
||||
cur_inpt = m_in * m_in
|
||||
net = []
|
||||
for i in range(1, self.num_layers):
|
||||
net.append(nn.LayerNorm(int(cur_inpt)))
|
||||
net.append(self.nonlin)
|
||||
# TODO: implement cleaner init
|
||||
net.append(
|
||||
nn.Linear(cur_inpt, n_hidden, bias=(i == self.num_layers - 1)))
|
||||
nn.init.kaiming_uniform_(net[-1].weight)
|
||||
cur_inpt = n_hidden
|
||||
return nn.Sequential(*net)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
output = {}
|
||||
for k, v in features.items():
|
||||
# v shape: [..., m, 2*k+1]
|
||||
first_dims = v.shape[:-2]
|
||||
m_in = self.f_in.structure_dict[int(k)]
|
||||
m_out = self.f_out.structure_dict[int(k)]
|
||||
assert v.shape[-2] == m_in
|
||||
assert v.shape[-1] == 2 * int(k) + 1
|
||||
|
||||
# Compute the norms and normalized features
|
||||
#norm = v.norm(p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(v)
|
||||
#phase = v / norm # [..., m, 2*k+1]
|
||||
scalars = torch.einsum('...ac,...bc->...ab', [v, v]) # [..., m_in, m_in]
|
||||
scalars = scalars.view(*first_dims, m_in*m_in) # [..., m_in*m_in]
|
||||
sign = scalars.sign()
|
||||
scalars = scalars.abs_().clamp_min(self.eps)
|
||||
scalars *= sign
|
||||
|
||||
# perform attention
|
||||
att_weights = self.transform[str(k)](scalars) # [..., m_out*m_in]
|
||||
att_weights = att_weights.view(*first_dims, m_out, m_in) # [..., m_out, m_in]
|
||||
att_weights = F.softmax(input=att_weights, dim=-1)
|
||||
# shape [..., m_out, 2*k+1]
|
||||
# output[k] = torch.einsum('...nm,...md->...nd', [att_weights, phase])
|
||||
output[k] = torch.einsum('...nm,...md->...nd', [att_weights, v])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class GNormSE3(nn.Module):
|
||||
"""Graph Norm-based SE(3)-equivariant nonlinearity.
|
||||
|
||||
Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
|
||||
expensive to compute, so it is convenient for them to share resources with
|
||||
other layers, such as normalization. The general workflow is as follows:
|
||||
|
||||
> for feature type in features:
|
||||
> norm, phase <- feature
|
||||
> output = fnc(norm) * phase
|
||||
|
||||
where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
|
||||
"""
|
||||
def __init__(self, fiber, nonlin=nn.ReLU(inplace=True), num_layers: int=0):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
fiber: Fiber() of feature multiplicities and types
|
||||
nonlin: nonlinearity to use everywhere
|
||||
num_layers: non-negative number of linear layers in fnc
|
||||
"""
|
||||
super().__init__()
|
||||
self.fiber = fiber
|
||||
self.nonlin = nonlin
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Regularization for computing phase: gradients explode otherwise
|
||||
self.eps = 1e-12
|
||||
|
||||
# Norm mappings: 1 per feature type
|
||||
self.transform = nn.ModuleDict()
|
||||
for m, d in self.fiber.structure:
|
||||
self.transform[str(d)] = self._build_net(int(m))
|
||||
|
||||
def __repr__(self):
|
||||
return f"GNormSE3(num_layers={self.num_layers}, nonlin={self.nonlin})"
|
||||
|
||||
def _build_net(self, m: int):
|
||||
net = []
|
||||
for i in range(self.num_layers):
|
||||
net.append(BN(int(m)))
|
||||
net.append(self.nonlin)
|
||||
# TODO: implement cleaner init
|
||||
net.append(nn.Linear(m, m, bias=(i==self.num_layers-1)))
|
||||
nn.init.kaiming_uniform_(net[-1].weight)
|
||||
if self.num_layers == 0:
|
||||
net.append(BN(int(m)))
|
||||
net.append(self.nonlin)
|
||||
return nn.Sequential(*net)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
output = {}
|
||||
for k, v in features.items():
|
||||
# Compute the norms and normalized features
|
||||
# v shape: [...,m , 2*k+1]
|
||||
norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps).expand_as(v)
|
||||
phase = v / norm
|
||||
|
||||
# Transform on norms
|
||||
transformed = self.transform[str(k)](norm[...,0]).unsqueeze(-1)
|
||||
|
||||
# Nonlinearity on norm
|
||||
output[k] = (transformed * phase).view(*v.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class BN(nn.Module):
|
||||
"""SE(3)-equvariant batch/layer normalization"""
|
||||
def __init__(self, m):
|
||||
"""SE(3)-equvariant batch/layer normalization
|
||||
|
||||
Args:
|
||||
m: int for number of output channels
|
||||
"""
|
||||
super().__init__()
|
||||
self.bn = nn.LayerNorm(m)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn(x)
|
||||
|
||||
|
||||
class GConvSE3Partial(nn.Module):
|
||||
"""Graph SE(3)-equivariant node -> edge layer"""
|
||||
def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
|
||||
"""SE(3)-equivariant partial convolution.
|
||||
|
||||
A partial convolution computes the inner product between a kernel and
|
||||
each input channel, without summing over the result from each input
|
||||
channel. This unfolded structure makes it amenable to be used for
|
||||
computing the value-embeddings of the attention mechanism.
|
||||
|
||||
Args:
|
||||
f_in: list of tuples [(multiplicities, type),...]
|
||||
f_out: list of tuples [(multiplicities, type),...]
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_out = f_out
|
||||
self.edge_dim = edge_dim
|
||||
|
||||
# adding/concatinating relative position to feature vectors
|
||||
# 'cat' concatenates relative position & existing feature vector
|
||||
# 'add' adds it, but only if multiplicity > 1
|
||||
assert x_ij in [None, 'cat', 'add']
|
||||
self.x_ij = x_ij
|
||||
if x_ij == 'cat':
|
||||
self.f_in = Fiber.combine(f_in, Fiber(structure=[(1,1)]))
|
||||
else:
|
||||
self.f_in = f_in
|
||||
|
||||
# Node -> edge weights
|
||||
self.kernel_unary = nn.ModuleDict()
|
||||
for (mi, di) in self.f_in.structure:
|
||||
for (mo, do) in self.f_out.structure:
|
||||
self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)
|
||||
|
||||
def __repr__(self):
|
||||
return f'GConvSE3Partial(structure={self.f_out})'
|
||||
|
||||
def udf_u_mul_e(self, d_out):
|
||||
"""Compute the partial convolution for a single output feature type.
|
||||
|
||||
This function is set up as a User Defined Function in DGL.
|
||||
|
||||
Args:
|
||||
d_out: output feature type
|
||||
Returns:
|
||||
node -> edge function handle
|
||||
"""
|
||||
def fnc(edges):
|
||||
# Neighbor -> center messages
|
||||
msg = 0
|
||||
for m_in, d_in in self.f_in.structure:
|
||||
# if type 1 and flag set, add relative position as feature
|
||||
if self.x_ij == 'cat' and d_in == 1:
|
||||
# relative positions
|
||||
rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
|
||||
m_ori = m_in - 1
|
||||
if m_ori == 0:
|
||||
# no type 1 input feature, just use relative position
|
||||
src = rel
|
||||
else:
|
||||
# features of src node, shape [edges, m_in*(2l+1), 1]
|
||||
src = edges.src[f'{d_in}'].view(-1, m_ori*(2*d_in+1), 1)
|
||||
# add to feature vector
|
||||
src = torch.cat([src, rel], dim=1)
|
||||
elif self.x_ij == 'add' and d_in == 1 and m_in > 1:
|
||||
src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
|
||||
rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
|
||||
src[..., :3, :1] = src[..., :3, :1] + rel
|
||||
else:
|
||||
src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
|
||||
edge = edges.data[f'({d_in},{d_out})']
|
||||
msg = msg + torch.matmul(edge, src)
|
||||
msg = msg.view(msg.shape[0], -1, 2*d_out+1)
|
||||
|
||||
return {f'out{d_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
|
||||
return fnc
|
||||
|
||||
def forward(self, h, G=None, r=None, basis=None, **kwargs):
|
||||
"""Forward pass of the linear layer
|
||||
|
||||
Args:
|
||||
h: dict of node-features
|
||||
G: minibatch of (homo)graphs
|
||||
r: inter-atomic distances
|
||||
basis: pre-computed Q * Y
|
||||
Returns:
|
||||
tensor with new features [B, n_points, n_features_out]
|
||||
"""
|
||||
with G.local_scope():
|
||||
# Add node features to local graph scope
|
||||
for k, v in h.items():
|
||||
G.ndata[k] = v
|
||||
|
||||
# Add edge features
|
||||
if 'w' in G.edata.keys():
|
||||
w = G.edata['w'] # shape: [#edges_in_batch, #bond_types]
|
||||
feat = torch.cat([w, r], -1)
|
||||
else:
|
||||
feat = torch.cat([r, ], -1)
|
||||
for (mi, di) in self.f_in.structure:
|
||||
for (mo, do) in self.f_out.structure:
|
||||
etype = f'({di},{do})'
|
||||
G.edata[etype] = self.kernel_unary[etype](feat, basis)
|
||||
|
||||
# Perform message-passing for each output feature type
|
||||
for d in self.f_out.degrees:
|
||||
G.apply_edges(self.udf_u_mul_e(d))
|
||||
|
||||
return {f'{d}': G.edata[f'out{d}'] for d in self.f_out.degrees}
|
||||
|
||||
|
||||
class GMABSE3(nn.Module):
|
||||
"""An SE(3)-equivariant multi-headed self-attention module for DGL graphs."""
|
||||
def __init__(self, f_value: Fiber, f_key: Fiber, n_heads: int):
|
||||
"""SE(3)-equivariant MAB (multi-headed attention block) layer.
|
||||
|
||||
Args:
|
||||
f_value: Fiber() object for value-embeddings
|
||||
f_key: Fiber() object for key-embeddings
|
||||
n_heads: number of heads
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_value = f_value
|
||||
self.f_key = f_key
|
||||
self.n_heads = n_heads
|
||||
self.new_dgl = version.parse(dgl.__version__) > version.parse('0.4.4')
|
||||
|
||||
def __repr__(self):
|
||||
return f'GMABSE3(n_heads={self.n_heads}, structure={self.f_value})'
|
||||
|
||||
def udf_u_mul_e(self, d_out):
|
||||
"""Compute the weighted sum for a single output feature type.
|
||||
|
||||
This function is set up as a User Defined Function in DGL.
|
||||
|
||||
Args:
|
||||
d_out: output feature type
|
||||
Returns:
|
||||
edge -> node function handle
|
||||
"""
|
||||
def fnc(edges):
|
||||
# Neighbor -> center messages
|
||||
attn = edges.data['a']
|
||||
value = edges.data[f'v{d_out}']
|
||||
|
||||
# Apply attention weights
|
||||
msg = attn.unsqueeze(-1).unsqueeze(-1) * value
|
||||
|
||||
return {'m': msg}
|
||||
return fnc
|
||||
|
||||
def forward(self, v, k: Dict=None, q: Dict=None, G=None, **kwargs):
|
||||
"""Forward pass of the linear layer
|
||||
|
||||
Args:
|
||||
G: minibatch of (homo)graphs
|
||||
v: dict of value edge-features
|
||||
k: dict of key edge-features
|
||||
q: dict of query node-features
|
||||
Returns:
|
||||
tensor with new features [B, n_points, n_features_out]
|
||||
"""
|
||||
with G.local_scope():
|
||||
# Add node features to local graph scope
|
||||
## We use the stacked tensor representation for attention
|
||||
for m, d in self.f_value.structure:
|
||||
G.edata[f'v{d}'] = v[f'{d}'].view(-1, self.n_heads, m//self.n_heads, 2*d+1)
|
||||
G.edata['k'] = fiber2head(k, self.n_heads, self.f_key, squeeze=True) # [edges, heads, channels](?)
|
||||
G.ndata['q'] = fiber2head(q, self.n_heads, self.f_key, squeeze=True) # [nodes, heads, channels](?)
|
||||
|
||||
# Compute attention weights
|
||||
## Inner product between (key) neighborhood and (query) center
|
||||
G.apply_edges(fn.e_dot_v('k', 'q', 'e'))
|
||||
|
||||
## Apply softmax
|
||||
e = G.edata.pop('e')
|
||||
if self.new_dgl:
|
||||
# in dgl 5.3, e has an extra dimension compared to dgl 4.3
|
||||
# the following, we get rid of this be reshaping
|
||||
n_edges = G.edata['k'].shape[0]
|
||||
e = e.view([n_edges, self.n_heads])
|
||||
e = e / np.sqrt(self.f_key.n_features)
|
||||
G.edata['a'] = edge_softmax(G, e)
|
||||
|
||||
# Perform attention-weighted message-passing
|
||||
for d in self.f_value.degrees:
|
||||
G.update_all(self.udf_u_mul_e(d), fn.sum('m', f'out{d}'))
|
||||
|
||||
output = {}
|
||||
for m, d in self.f_value.structure:
|
||||
output[f'{d}'] = G.ndata[f'out{d}'].view(-1, m, 2*d+1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GSE3Res(nn.Module):
|
||||
"""Graph attention block with SE(3)-equivariance and skip connection"""
|
||||
def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4,
|
||||
n_heads: int=1, learnable_skip=True, skip='cat', selfint='1x1', x_ij=None):
|
||||
super().__init__()
|
||||
self.f_in = f_in
|
||||
self.f_out = f_out
|
||||
self.div = div
|
||||
self.n_heads = n_heads
|
||||
self.skip = skip # valid: 'cat', 'sum', None
|
||||
|
||||
# f_mid_out has same structure as 'f_out' but #channels divided by 'div'
|
||||
# this will be used for the values
|
||||
f_mid_out = {k: int(v // div) for k, v in self.f_out.structure_dict.items()}
|
||||
self.f_mid_out = Fiber(dictionary=f_mid_out)
|
||||
|
||||
# f_mid_in has same structure as f_mid_out, but only degrees which are in f_in
|
||||
# this will be used for keys and queries
|
||||
# (queries are merely projected, hence degrees have to match input)
|
||||
f_mid_in = {d: m for d, m in f_mid_out.items() if d in self.f_in.degrees}
|
||||
self.f_mid_in = Fiber(dictionary=f_mid_in)
|
||||
|
||||
self.edge_dim = edge_dim
|
||||
|
||||
self.GMAB = nn.ModuleDict()
|
||||
|
||||
# Projections
|
||||
self.GMAB['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim, x_ij=x_ij)
|
||||
self.GMAB['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim, x_ij=x_ij)
|
||||
self.GMAB['q'] = G1x1SE3(f_in, self.f_mid_in)
|
||||
|
||||
# Attention
|
||||
self.GMAB['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)
|
||||
|
||||
# Skip connections
|
||||
if self.skip == 'cat':
|
||||
self.cat = GCat(self.f_mid_out, f_in)
|
||||
if selfint == 'att':
|
||||
self.project = GAttentiveSelfInt(self.cat.f_out, f_out)
|
||||
elif selfint == '1x1':
|
||||
self.project = G1x1SE3(self.cat.f_out, f_out, learnable=learnable_skip)
|
||||
elif self.skip == 'sum':
|
||||
self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
|
||||
self.add = GSum(f_out, f_in)
|
||||
# the following checks whether the skip connection would change
|
||||
# the output fibre strucure; the reason can be that the input has
|
||||
# more channels than the ouput (for at least one degree); this would
|
||||
# then cause a (hard to debug) error in the next layer
|
||||
assert self.add.f_out.structure_dict == f_out.structure_dict, \
|
||||
'skip connection would change output structure'
|
||||
|
||||
def forward(self, features, G, **kwargs):
|
||||
# Embeddings
|
||||
v = self.GMAB['v'](features, G=G, **kwargs)
|
||||
k = self.GMAB['k'](features, G=G, **kwargs)
|
||||
q = self.GMAB['q'](features, G=G)
|
||||
|
||||
# Attention
|
||||
z = self.GMAB['attn'](v, k=k, q=q, G=G)
|
||||
|
||||
if self.skip == 'cat':
|
||||
z = self.cat(z, features)
|
||||
z = self.project(z)
|
||||
elif self.skip == 'sum':
|
||||
# Skip + residual
|
||||
z = self.project(z)
|
||||
z = self.add(z, features)
|
||||
return z
|
||||
|
||||
### Helper and wrapper functions
|
||||
|
||||
class GSum(nn.Module):
|
||||
"""SE(3)-equvariant graph residual sum function."""
|
||||
def __init__(self, f_x: Fiber, f_y: Fiber):
|
||||
"""SE(3)-equvariant graph residual sum function.
|
||||
|
||||
Args:
|
||||
f_x: Fiber() object for fiber of summands
|
||||
f_y: Fiber() object for fiber of summands
|
||||
"""
|
||||
super().__init__()
|
||||
self.f_x = f_x
|
||||
self.f_y = f_y
|
||||
self.f_out = Fiber.combine_max(f_x, f_y)
|
||||
|
||||
def __repr__(self):
|
||||
return f"GSum(structure={self.f_out})"
|
||||
|
||||
def forward(self, x, y):
|
||||
out = {}
|
||||
for k in self.f_out.degrees:
|
||||
k = str(k)
|
||||
if (k in x) and (k in y):
|
||||
if x[k].shape[1] > y[k].shape[1]:
|
||||
diff = x[k].shape[1] - y[k].shape[1]
|
||||
zeros = torch.zeros(x[k].shape[0], diff, x[k].shape[2]).to(y[k].device)
|
||||
y[k] = torch.cat([y[k], zeros], 1)
|
||||
elif x[k].shape[1] < y[k].shape[1]:
|
||||
diff = y[k].shape[1] - x[k].shape[1]
|
||||
zeros = torch.zeros(x[k].shape[0], diff, x[k].shape[2]).to(y[k].device)
|
||||
x[k] = torch.cat([x[k], zeros], 1)
|
||||
|
||||
out[k] = x[k] + y[k]
|
||||
elif k in x:
|
||||
out[k] = x[k]
|
||||
elif k in y:
|
||||
out[k] = y[k]
|
||||
return out
|
||||
|
||||
|
||||
class GCat(nn.Module):
|
||||
"""Concat only degrees which are in f_x"""
|
||||
def __init__(self, f_x: Fiber, f_y: Fiber):
|
||||
super().__init__()
|
||||
self.f_x = f_x
|
||||
self.f_y = f_y
|
||||
f_out = {}
|
||||
for k in f_x.degrees:
|
||||
f_out[k] = f_x.dict[k]
|
||||
if k in f_y.degrees:
|
||||
f_out[k] += f_y.dict[k]
|
||||
self.f_out = Fiber(dictionary=f_out)
|
||||
|
||||
def __repr__(self):
|
||||
return f"GCat(structure={self.f_out})"
|
||||
|
||||
def forward(self, x, y):
|
||||
out = {}
|
||||
for k in self.f_out.degrees:
|
||||
k = str(k)
|
||||
if k in y:
|
||||
out[k] = torch.cat([x[k], y[k]], 1)
|
||||
else:
|
||||
out[k] = x[k]
|
||||
return out
|
||||
|
||||
|
||||
class GAvgPooling(nn.Module):
|
||||
"""Graph Average Pooling module."""
|
||||
def __init__(self, type='0'):
|
||||
super().__init__()
|
||||
self.pool = AvgPooling()
|
||||
self.type = type
|
||||
|
||||
def forward(self, features, G, **kwargs):
|
||||
if self.type == '0':
|
||||
h = features['0'][...,-1]
|
||||
pooled = self.pool(G, h)
|
||||
elif self.type == '1':
|
||||
pooled = []
|
||||
for i in range(3):
|
||||
h_i = features['1'][..., i]
|
||||
pooled.append(self.pool(G, h_i).unsqueeze(-1))
|
||||
pooled = torch.cat(pooled, axis=-1)
|
||||
pooled = {'1': pooled}
|
||||
else:
|
||||
print('GAvgPooling for type > 0 not implemented')
|
||||
exit()
|
||||
return pooled
|
||||
|
||||
|
||||
class GMaxPooling(nn.Module):
|
||||
"""Graph Max Pooling module."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool = MaxPooling()
|
||||
|
||||
def forward(self, features, G, **kwargs):
|
||||
h = features['0'][...,-1]
|
||||
return self.pool(G, h)
|
||||
|
||||
|
91
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/ffindex.py
Normal file
91
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/ffindex.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python
|
||||
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
|
||||
|
||||
'''
|
||||
Created on Apr 30, 2014
|
||||
|
||||
@author: meiermark
|
||||
'''
|
||||
|
||||
|
||||
import sys
|
||||
import mmap
|
||||
from collections import namedtuple
|
||||
|
||||
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
|
||||
|
||||
|
||||
def read_index(ffindex_filename):
|
||||
entries = []
|
||||
|
||||
fh = open(ffindex_filename)
|
||||
for line in fh:
|
||||
tokens = line.split("\t")
|
||||
entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
|
||||
fh.close()
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def read_data(ffdata_filename):
|
||||
fh = open(ffdata_filename, "rb")
|
||||
data = mmap.mmap(fh.fileno(), 0, prot=mmap.PROT_READ)
|
||||
fh.close()
|
||||
return data
|
||||
|
||||
|
||||
def get_entry_by_name(name, index):
|
||||
#TODO: bsearch
|
||||
for entry in index:
|
||||
if(name == entry.name):
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
||||
def read_entry_lines(entry, data):
|
||||
lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
|
||||
return lines
|
||||
|
||||
|
||||
def read_entry_data(entry, data):
|
||||
return data[entry.offset:entry.offset + entry.length - 1]
|
||||
|
||||
|
||||
def write_entry(entries, data_fh, entry_name, offset, data):
|
||||
data_fh.write(data[:-1])
|
||||
data_fh.write(bytearray(1))
|
||||
|
||||
entry = FFindexEntry(entry_name, offset, len(data))
|
||||
entries.append(entry)
|
||||
|
||||
return offset + len(data)
|
||||
|
||||
|
||||
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
|
||||
with open(file_name, "rb") as fh:
|
||||
data = bytearray(fh.read())
|
||||
return write_entry(entries, data_fh, entry_name, offset, data)
|
||||
|
||||
|
||||
def finish_db(entries, ffindex_filename, data_fh):
|
||||
data_fh.close()
|
||||
write_entries_to_db(entries, ffindex_filename)
|
||||
|
||||
|
||||
def write_entries_to_db(entries, ffindex_filename):
|
||||
sorted(entries, key=lambda x: x.name)
|
||||
index_fh = open(ffindex_filename, "w")
|
||||
|
||||
for entry in entries:
|
||||
index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
|
||||
|
||||
index_fh.close()
|
||||
|
||||
|
||||
def write_entry_to_file(entry, data, file):
|
||||
lines = read_lines(entry, data)
|
||||
|
||||
fh = open(file, "w")
|
||||
for line in lines:
|
||||
fh.write(line+"\n")
|
||||
fh.close()
|
216
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/kinematics.py
Normal file
216
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/kinematics.py
Normal file
|
@ -0,0 +1,216 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
PARAMS = {
|
||||
"DMIN" : 2.0,
|
||||
"DMAX" : 20.0,
|
||||
"DBINS" : 36,
|
||||
"ABINS" : 36,
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
def get_pair_dist(a, b):
|
||||
"""calculate pair distances between two sets of points
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of two sets of atoms
|
||||
Returns
|
||||
-------
|
||||
dist : pytorch tensor of shape [batch,nres,nres]
|
||||
stores paitwise distances between atoms in a and b
|
||||
"""
|
||||
|
||||
dist = torch.cdist(a, b, p=2)
|
||||
return dist
|
||||
|
||||
# ============================================================
|
||||
def get_ang(a, b, c):
|
||||
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
|
||||
from Cartesian coordinates of three sets of atoms a,b,c
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b,c : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of three sets of atoms
|
||||
Returns
|
||||
-------
|
||||
ang : pytorch tensor of shape [batch,nres]
|
||||
stores resulting planar angles
|
||||
"""
|
||||
v = a - b
|
||||
w = c - b
|
||||
v /= torch.norm(v, dim=-1, keepdim=True)
|
||||
w /= torch.norm(w, dim=-1, keepdim=True)
|
||||
vw = torch.sum(v*w, dim=-1)
|
||||
|
||||
return torch.acos(vw)
|
||||
|
||||
# ============================================================
|
||||
def get_dih(a, b, c, d):
|
||||
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
|
||||
given Cartesian coordinates of four sets of atoms a,b,c,d
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b,c,d : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of four sets of atoms
|
||||
Returns
|
||||
-------
|
||||
dih : pytorch tensor of shape [batch,nres]
|
||||
stores resulting dihedrals
|
||||
"""
|
||||
b0 = a - b
|
||||
b1 = c - b
|
||||
b2 = d - c
|
||||
|
||||
b1 /= torch.norm(b1, dim=-1, keepdim=True)
|
||||
|
||||
v = b0 - torch.sum(b0*b1, dim=-1, keepdim=True)*b1
|
||||
w = b2 - torch.sum(b2*b1, dim=-1, keepdim=True)*b1
|
||||
|
||||
x = torch.sum(v*w, dim=-1)
|
||||
y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1)
|
||||
|
||||
return torch.atan2(y, x)
|
||||
|
||||
|
||||
# ============================================================
|
||||
def xyz_to_c6d(xyz, params=PARAMS):
|
||||
"""convert cartesian coordinates into 2d distance
|
||||
and orientation maps
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xyz : pytorch tensor of shape [batch,nres,3,3]
|
||||
stores Cartesian coordinates of backbone N,Ca,C atoms
|
||||
Returns
|
||||
-------
|
||||
c6d : pytorch tensor of shape [batch,nres,nres,4]
|
||||
stores stacked dist,omega,theta,phi 2D maps
|
||||
"""
|
||||
|
||||
batch = xyz.shape[0]
|
||||
nres = xyz.shape[1]
|
||||
|
||||
# three anchor atoms
|
||||
N = xyz[:,:,0]
|
||||
Ca = xyz[:,:,1]
|
||||
C = xyz[:,:,2]
|
||||
|
||||
# recreate Cb given N,Ca,C
|
||||
b = Ca - N
|
||||
c = C - Ca
|
||||
a = torch.cross(b, c, dim=-1)
|
||||
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||
|
||||
# 6d coordinates order: (dist,omega,theta,phi)
|
||||
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
|
||||
|
||||
dist = get_pair_dist(Cb,Cb)
|
||||
dist[torch.isnan(dist)] = 999.9
|
||||
c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
|
||||
b,i,j = torch.where(c6d[...,0]<params['DMAX'])
|
||||
|
||||
c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
|
||||
c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
|
||||
c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
|
||||
|
||||
# fix long-range distances
|
||||
c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
|
||||
|
||||
mask = torch.zeros((batch, nres,nres), dtype=xyz.dtype, device=xyz.device)
|
||||
mask[b,i,j] = 1.0
|
||||
return c6d, mask
|
||||
|
||||
def xyz_to_t2d(xyz_t, t0d, params=PARAMS):
|
||||
"""convert template cartesian coordinates into 2d distance
|
||||
and orientation maps
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
|
||||
stores Cartesian coordinates of template backbone N,Ca,C atoms
|
||||
t0d: 0-D template features (HHprob, seqID, similarity) [batch, templ, 3]
|
||||
|
||||
Returns
|
||||
-------
|
||||
t2d : pytorch tensor of shape [batch,nres,nres,1+6+3]
|
||||
stores stacked dist,omega,theta,phi 2D maps
|
||||
"""
|
||||
B, T, L = xyz_t.shape[:3]
|
||||
c6d, mask = xyz_to_c6d(xyz_t.view(B*T,L,3,3), params=params)
|
||||
c6d = c6d.view(B, T, L, L, 4)
|
||||
mask = mask.view(B, T, L, L, 1)
|
||||
#
|
||||
dist = c6d[...,:1]*mask / params['DMAX'] # from 0 to 1 # (B, T, L, L, 1)
|
||||
dist = torch.clamp(dist, 0.0, 1.0)
|
||||
orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6)
|
||||
t0d = t0d.unsqueeze(2).unsqueeze(3).expand(-1, -1, L, L, -1)
|
||||
#
|
||||
t2d = torch.cat((dist, orien, t0d), dim=-1)
|
||||
t2d[torch.isnan(t2d)] = 0.0
|
||||
return t2d
|
||||
|
||||
# ============================================================
|
||||
def c6d_to_bins(c6d,params=PARAMS):
|
||||
"""bin 2d distance and orientation maps
|
||||
"""
|
||||
|
||||
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
|
||||
astep = 2.0*np.pi / params['ABINS']
|
||||
|
||||
dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=c6d.dtype,device=c6d.device)
|
||||
ab360 = torch.linspace(-np.pi+astep, np.pi, params['ABINS'],dtype=c6d.dtype,device=c6d.device)
|
||||
ab180 = torch.linspace(astep, np.pi, params['ABINS']//2,dtype=c6d.dtype,device=c6d.device)
|
||||
|
||||
db = torch.bucketize(c6d[...,0].contiguous(),dbins)
|
||||
ob = torch.bucketize(c6d[...,1].contiguous(),ab360)
|
||||
tb = torch.bucketize(c6d[...,2].contiguous(),ab360)
|
||||
pb = torch.bucketize(c6d[...,3].contiguous(),ab180)
|
||||
|
||||
ob[db==params['DBINS']] = params['ABINS']
|
||||
tb[db==params['DBINS']] = params['ABINS']
|
||||
pb[db==params['DBINS']] = params['ABINS']//2
|
||||
|
||||
return torch.stack([db,ob,tb,pb],axis=-1).to(torch.uint8)
|
||||
|
||||
|
||||
# ============================================================
|
||||
def dist_to_bins(dist,params=PARAMS):
|
||||
"""bin 2d distance maps
|
||||
"""
|
||||
|
||||
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
|
||||
db = torch.round((dist-params['DMIN']-dstep/2)/dstep)
|
||||
|
||||
db[db<0] = 0
|
||||
db[db>params['DBINS']] = params['DBINS']
|
||||
|
||||
return db.long()
|
||||
|
||||
|
||||
# ============================================================
|
||||
def c6d_to_bins2(c6d,params=PARAMS):
|
||||
"""bin 2d distance and orientation maps
|
||||
"""
|
||||
|
||||
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
|
||||
astep = 2.0*np.pi / params['ABINS']
|
||||
|
||||
db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep)
|
||||
ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep)
|
||||
tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep)
|
||||
pb = torch.round((c6d[...,3]-astep/2)/astep)
|
||||
|
||||
# put all d<dmin into one bin
|
||||
db[db<0] = 0
|
||||
|
||||
# synchronize no-contact bins
|
||||
db[db>params['DBINS']] = params['DBINS']
|
||||
ob[db==params['DBINS']] = params['ABINS']
|
||||
tb[db==params['DBINS']] = params['ABINS']
|
||||
pb[db==params['DBINS']] = params['ABINS']//2
|
||||
|
||||
return torch.stack([db,ob,tb,pb],axis=-1).long()
|
255
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/parsers.py
Normal file
255
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/parsers.py
Normal file
|
@ -0,0 +1,255 @@
|
|||
import numpy as np
|
||||
import scipy
|
||||
import scipy.spatial
|
||||
import string
|
||||
import os,re
|
||||
import random
|
||||
import util
|
||||
import torch
|
||||
from ffindex import *
|
||||
|
||||
to1letter = {
|
||||
"ALA":'A', "ARG":'R', "ASN":'N', "ASP":'D', "CYS":'C',
|
||||
"GLN":'Q', "GLU":'E', "GLY":'G', "HIS":'H', "ILE":'I',
|
||||
"LEU":'L', "LYS":'K', "MET":'M', "PHE":'F', "PRO":'P',
|
||||
"SER":'S', "THR":'T', "TRP":'W', "TYR":'Y', "VAL":'V' }
|
||||
|
||||
# read A3M and convert letters into
|
||||
# integers in the 0..20 range,
|
||||
def parse_a3m(filename):
|
||||
|
||||
msa = []
|
||||
|
||||
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
||||
|
||||
# read file line by line
|
||||
for line in open(filename,"r"):
|
||||
|
||||
# skip labels
|
||||
if line[0] == '>':
|
||||
continue
|
||||
|
||||
# remove right whitespaces
|
||||
line = line.rstrip()
|
||||
|
||||
# remove lowercase letters and append to MSA
|
||||
msa.append(line.translate(table))
|
||||
|
||||
# convert letters into numbers
|
||||
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
|
||||
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
|
||||
for i in range(alphabet.shape[0]):
|
||||
msa[msa == alphabet[i]] = i
|
||||
|
||||
# treat all unknown characters as gaps
|
||||
msa[msa > 20] = 20
|
||||
|
||||
return msa
|
||||
|
||||
# parse HHsearch output
|
||||
def parse_hhr(filename, ffindex, idmax=105.0):
|
||||
|
||||
# labels present in the database
|
||||
label_set = set([i.name for i in ffindex])
|
||||
|
||||
out = []
|
||||
|
||||
with open(filename, "r") as hhr:
|
||||
|
||||
# read .hhr into a list of lines
|
||||
lines = [s.rstrip() for _,s in enumerate(hhr)]
|
||||
|
||||
# read list of all hits
|
||||
start = lines.index("") + 2
|
||||
stop = lines[start:].index("") + start
|
||||
hits = []
|
||||
for line in lines[start:stop]:
|
||||
|
||||
# ID of the hit
|
||||
#label = re.sub('_','',line[4:10].strip())
|
||||
label = line[4:10].strip()
|
||||
|
||||
# position in the query where the alignment starts
|
||||
qstart = int(line[75:84].strip().split("-")[0])-1
|
||||
|
||||
# position in the template where the alignment starts
|
||||
tstart = int(line[85:94].strip().split("-")[0])-1
|
||||
|
||||
hits.append([label, qstart, tstart, int(line[69:75])])
|
||||
|
||||
# get line numbers where each hit starts
|
||||
start = [i for i,l in enumerate(lines) if l and l[0]==">"] # and l[1:].strip() in label_set]
|
||||
|
||||
# process hits
|
||||
for idx,i in enumerate(start):
|
||||
|
||||
# skip if hit is too short
|
||||
if hits[idx][3] < 10:
|
||||
continue
|
||||
|
||||
# skip if template is not in the database
|
||||
if hits[idx][0] not in label_set:
|
||||
continue
|
||||
|
||||
# get hit statistics
|
||||
p,e,s,_,seqid,sim,_,neff = [float(s) for s in re.sub('[=%]', ' ', lines[i+1]).split()[1::2]]
|
||||
|
||||
# skip too similar hits
|
||||
if seqid > idmax:
|
||||
continue
|
||||
|
||||
query = np.array(list(lines[i+4].split()[3]), dtype='|S1')
|
||||
tmplt = np.array(list(lines[i+8].split()[3]), dtype='|S1')
|
||||
|
||||
simlr = np.array(list(lines[i+6][22:]), dtype='|S1').view(np.uint8)
|
||||
abc = np.array(list(" =-.+|"), dtype='|S1').view(np.uint8)
|
||||
for k in range(abc.shape[0]):
|
||||
simlr[simlr == abc[k]] = k
|
||||
|
||||
confd = np.array(list(lines[i+11][22:]), dtype='|S1').view(np.uint8)
|
||||
abc = np.array(list(" 0123456789"), dtype='|S1').view(np.uint8)
|
||||
for k in range(abc.shape[0]):
|
||||
confd[confd == abc[k]] = k
|
||||
|
||||
qj = np.cumsum(query!=b'-') + hits[idx][1]
|
||||
tj = np.cumsum(tmplt!=b'-') + hits[idx][2]
|
||||
|
||||
# matched positions
|
||||
matches = np.array([[q-1,t-1,s-1,c-1] for q,t,s,c in zip(qj,tj,simlr,confd) if s>0])
|
||||
|
||||
# skip short hits
|
||||
ncol = matches.shape[0]
|
||||
if ncol<10:
|
||||
continue
|
||||
|
||||
# save hit
|
||||
#out.update({hits[idx][0] : [matches,p/100,seqid/100,neff/10]})
|
||||
out.append([hits[idx][0],matches,p/100,seqid/100,sim/10])
|
||||
|
||||
return out
|
||||
|
||||
# read and extract xyz coords of N,Ca,C atoms
|
||||
# from a PDB file
|
||||
def parse_pdb(filename):
|
||||
|
||||
lines = open(filename,'r').readlines()
|
||||
|
||||
N = np.array([[float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||
for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="N"])
|
||||
Ca = np.array([[float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||
for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"])
|
||||
C = np.array([[float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||
for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="C"])
|
||||
|
||||
xyz = np.stack([N,Ca,C], axis=0)
|
||||
|
||||
# indices of residues observed in the structure
|
||||
idx = np.array([int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"])
|
||||
|
||||
return xyz,idx
|
||||
|
||||
def parse_pdb_lines(lines):
|
||||
|
||||
# indices of residues observed in the structure
|
||||
idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
|
||||
|
||||
# 4 BB + up to 10 SC atoms
|
||||
xyz = np.full((len(idx_s), 14, 3), np.nan, dtype=np.float32)
|
||||
for l in lines:
|
||||
if l[:4] != "ATOM":
|
||||
continue
|
||||
resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
|
||||
idx = idx_s.index(resNo)
|
||||
for i_atm, tgtatm in enumerate(util.aa2long[util.aa2num[aa]]):
|
||||
if tgtatm == atom:
|
||||
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||
break
|
||||
|
||||
# save atom mask
|
||||
mask = np.logical_not(np.isnan(xyz[...,0]))
|
||||
xyz[np.isnan(xyz[...,0])] = 0.0
|
||||
|
||||
return xyz,mask,np.array(idx_s)
|
||||
|
||||
def parse_templates(ffdb, hhr_fn, atab_fn, n_templ=10):
|
||||
|
||||
# process tabulated hhsearch output to get
|
||||
# matched positions and positional scores
|
||||
infile = atab_fn
|
||||
hits = []
|
||||
for l in open(infile, "r").readlines():
|
||||
if l[0]=='>':
|
||||
key = l[1:].split()[0]
|
||||
hits.append([key,[],[]])
|
||||
elif "score" in l or "dssp" in l:
|
||||
continue
|
||||
else:
|
||||
hi = l.split()[:5]+[0.0,0.0,0.0]
|
||||
hits[-1][1].append([int(hi[0]),int(hi[1])])
|
||||
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
|
||||
|
||||
# get per-hit statistics from an .hhr file
|
||||
# (!!! assume that .hhr and .atab have the same hits !!!)
|
||||
# [Probab, E-value, Score, Aligned_cols,
|
||||
# Identities, Similarity, Sum_probs, Template_Neff]
|
||||
lines = open(hhr_fn, "r").readlines()
|
||||
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
|
||||
for i,posi in enumerate(pos):
|
||||
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
|
||||
|
||||
# parse templates from FFDB
|
||||
for hi in hits:
|
||||
#if hi[0] not in ffids:
|
||||
# continue
|
||||
entry = get_entry_by_name(hi[0], ffdb.index)
|
||||
if entry == None:
|
||||
continue
|
||||
data = read_entry_lines(entry, ffdb.data)
|
||||
hi += list(parse_pdb_lines(data))
|
||||
|
||||
# process hits
|
||||
counter = 0
|
||||
xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
|
||||
for data in hits:
|
||||
if len(data)<7:
|
||||
continue
|
||||
|
||||
qi,ti = np.array(data[1]).T
|
||||
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
|
||||
ncol = sel1.shape[0]
|
||||
if ncol < 10:
|
||||
continue
|
||||
|
||||
ids.append(data[0])
|
||||
f0d.append(data[3])
|
||||
f1d.append(np.array(data[2])[sel1])
|
||||
xyz.append(data[4][sel2])
|
||||
mask.append(data[5][sel2])
|
||||
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
|
||||
counter += 1
|
||||
|
||||
xyz = np.vstack(xyz).astype(np.float32)
|
||||
qmap = np.vstack(qmap).astype(np.long)
|
||||
f0d = np.vstack(f0d).astype(np.float32)
|
||||
f1d = np.vstack(f1d).astype(np.float32)
|
||||
ids = ids
|
||||
|
||||
return torch.from_numpy(xyz), torch.from_numpy(qmap), \
|
||||
torch.from_numpy(f0d), torch.from_numpy(f1d), ids
|
||||
|
||||
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
|
||||
xyz_t, qmap, t0d, t1d, ids = parse_templates(ffdb, hhr_fn, atab_fn)
|
||||
npick = min(n_templ, len(ids))
|
||||
sample = torch.arange(npick)
|
||||
#
|
||||
xyz = torch.full((npick, qlen, 3, 3), np.nan).float()
|
||||
f1d = torch.zeros((npick, qlen, 3)).float()
|
||||
f0d = list()
|
||||
#
|
||||
for i, nt in enumerate(sample):
|
||||
sel = torch.where(qmap[:,1] == nt)[0]
|
||||
pos = qmap[sel, 0]
|
||||
xyz[i, pos] = xyz_t[sel, :3]
|
||||
f1d[i, pos] = t1d[sel, :3]
|
||||
f0d.append(torch.stack([t0d[nt,0]/100.0, t0d[nt, 4]/100.0, t0d[nt,5]], dim=-1))
|
||||
return xyz, f1d, torch.stack(f0d, dim=0)
|
|
@ -0,0 +1,284 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from functools import partial
|
||||
|
||||
# Original implementation from https://github.com/lucidrains/performer-pytorch
|
||||
|
||||
# helpers
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def empty(tensor):
|
||||
return tensor.numel() == 0
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def find_modules(nn_module, type):
|
||||
return [module for module in nn_module.modules() if isinstance(module, type)]
|
||||
|
||||
# kernel functions
|
||||
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
|
||||
b, h, *_ = data.shape
|
||||
|
||||
data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
|
||||
|
||||
ratio = (projection_matrix.shape[0] ** -0.5)
|
||||
|
||||
#projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
|
||||
projection = projection_matrix.unsqueeze(0).repeat(h, 1, 1)
|
||||
projection = projection.unsqueeze(0).repeat(b, 1, 1, 1) # (b,h,j,d)
|
||||
projection = projection.type_as(data)
|
||||
|
||||
data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
|
||||
|
||||
diag_data = data ** 2
|
||||
diag_data = torch.sum(diag_data, dim=-1)
|
||||
diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
|
||||
diag_data = diag_data.unsqueeze(dim=-1)
|
||||
|
||||
if is_query:
|
||||
data_dash = ratio * (
|
||||
torch.exp(data_dash - diag_data -
|
||||
torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
|
||||
else:
|
||||
data_dash = ratio * (
|
||||
torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)
|
||||
|
||||
return data_dash.type_as(data)
|
||||
|
||||
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(inplace=True), kernel_epsilon = 0.001, normalize_data = True, device = None):
|
||||
b, h, *_ = data.shape
|
||||
|
||||
data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
|
||||
|
||||
if projection_matrix is None:
|
||||
return kernel_fn(data_normalizer * data) + kernel_epsilon
|
||||
|
||||
data = data_normalizer*data
|
||||
data = torch.matmul(data, projection_matrix.T)
|
||||
data = kernel_fn(data) + kernel_epsilon
|
||||
return data.type_as(data)
|
||||
|
||||
def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
|
||||
unstructured_block = torch.randn((cols, cols), device = device)
|
||||
q, r = torch.linalg.qr(unstructured_block.cpu(), 'reduced')
|
||||
q, r = map(lambda t: t.to(device), (q, r))
|
||||
|
||||
# proposed by @Parskatt
|
||||
# to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
|
||||
if qr_uniform_q:
|
||||
d = torch.diag(r, 0)
|
||||
q *= d.sign()
|
||||
return q.t()
|
||||
|
||||
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
|
||||
nb_full_blocks = int(nb_rows / nb_columns)
|
||||
|
||||
block_list = []
|
||||
|
||||
for _ in range(nb_full_blocks):
|
||||
q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
|
||||
block_list.append(q)
|
||||
|
||||
remaining_rows = nb_rows - nb_full_blocks * nb_columns
|
||||
if remaining_rows > 0:
|
||||
q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
|
||||
block_list.append(q[:remaining_rows])
|
||||
|
||||
final_matrix = torch.cat(block_list)
|
||||
|
||||
if scaling == 0:
|
||||
multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
|
||||
elif scaling == 1:
|
||||
multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
|
||||
else:
|
||||
raise ValueError(f'Invalid scaling {scaling}')
|
||||
|
||||
return torch.diag(multiplier) @ final_matrix
|
||||
|
||||
# linear attention classes with softmax kernel
|
||||
|
||||
# non-causal linear attention
|
||||
def linear_attention(q, k, v):
|
||||
L = k.shape[-2]
|
||||
D_inv = 1. / torch.einsum('...nd,...d->...n', q, k.mean(dim=-2))
|
||||
context = torch.einsum('...nd,...ne->...de', k/float(L), v)
|
||||
del k, v
|
||||
out = torch.einsum('...n,...nd->...nd', D_inv, q)
|
||||
del D_inv, q
|
||||
out = torch.einsum('...nd,...de->...ne', out, context)
|
||||
return out
|
||||
|
||||
class FastAttention(nn.Module):
|
||||
def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, generalized_attention = False, kernel_fn = nn.ReLU(inplace=True), qr_uniform_q = False, no_projection = False):
|
||||
super().__init__()
|
||||
nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
|
||||
|
||||
self.dim_heads = dim_heads
|
||||
self.nb_features = nb_features
|
||||
self.ortho_scaling = ortho_scaling
|
||||
|
||||
if not no_projection:
|
||||
self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
|
||||
projection_matrix = self.create_projection()
|
||||
self.register_buffer('projection_matrix', projection_matrix)
|
||||
|
||||
self.generalized_attention = generalized_attention
|
||||
self.kernel_fn = kernel_fn
|
||||
|
||||
# if this is turned on, no projection will be used
|
||||
# queries and keys will be softmax-ed as in the original efficient attention paper
|
||||
self.no_projection = no_projection
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def redraw_projection_matrix(self, device):
|
||||
projections = self.create_projection(device = device)
|
||||
self.projection_matrix.copy_(projections)
|
||||
del projections
|
||||
|
||||
def forward(self, q, k, v):
|
||||
device = q.device
|
||||
|
||||
if self.no_projection:
|
||||
q = q.softmax(dim = -1)
|
||||
k.softmax(dim = -2)
|
||||
|
||||
elif self.generalized_attention:
|
||||
create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
|
||||
q, k = map(create_kernel, (q, k))
|
||||
|
||||
else:
|
||||
create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
|
||||
q = create_kernel(q, is_query = True)
|
||||
k = create_kernel(k, is_query = False)
|
||||
|
||||
attn_fn = linear_attention
|
||||
out = attn_fn(q, k, v)
|
||||
return out
|
||||
|
||||
# classes
|
||||
class ReZero(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.g = nn.Parameter(torch.tensor(1e-3))
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) * self.g
|
||||
|
||||
class PreScaleNorm(nn.Module):
|
||||
def __init__(self, dim, fn, eps=1e-5):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
|
||||
x = x / n * self.g
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class PreLayerNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class Chunk(nn.Module):
|
||||
def __init__(self, chunks, fn, along_dim = -1):
|
||||
super().__init__()
|
||||
self.dim = along_dim
|
||||
self.chunks = chunks
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
if self.chunks == 1:
|
||||
return self.fn(x, **kwargs)
|
||||
chunks = x.chunk(self.chunks, dim = self.dim)
|
||||
return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim, k_dim=None, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(inplace=True), qr_uniform_q = False, dropout = 0., no_projection = False):
|
||||
super().__init__()
|
||||
assert dim % heads == 0, 'dimension must be divisible by number of heads'
|
||||
dim_head = dim // heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
if k_dim == None:
|
||||
k_dim = dim
|
||||
|
||||
self.fast_attention = FastAttention(dim_head, nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
|
||||
|
||||
self.heads = heads
|
||||
self.dim = dim
|
||||
|
||||
self.to_query = nn.Linear(dim, inner_dim)
|
||||
self.to_key = nn.Linear(k_dim, inner_dim)
|
||||
self.to_value = nn.Linear(k_dim, inner_dim)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
self.dropout = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
self.feature_redraw_interval = feature_redraw_interval
|
||||
self.register_buffer("calls_since_last_redraw", torch.tensor(0))
|
||||
|
||||
self.max_tokens = 2**16
|
||||
|
||||
def check_redraw_projections(self):
|
||||
if not self.training:
|
||||
return
|
||||
|
||||
if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
|
||||
device = get_module_device(self)
|
||||
|
||||
fast_attentions = find_modules(self, FastAttention)
|
||||
for fast_attention in fast_attentions:
|
||||
fast_attention.redraw_projection_matrix(device)
|
||||
|
||||
self.calls_since_last_redraw.zero_()
|
||||
return
|
||||
|
||||
self.calls_since_last_redraw += 1
|
||||
|
||||
def _batched_forward(self, q, k, v):
|
||||
b1, h, n1 = q.shape[:3]
|
||||
out = torch.empty((b1, h, n1, self.dim//h), dtype=q.dtype, device=q.device)
|
||||
shift = self.max_tokens // n1
|
||||
for i_b in range(0, b1, shift):
|
||||
start = i_b
|
||||
end = min(i_b+shift, b1)
|
||||
out[start:end] = self.fast_attention(q[start:end], k[start:end], v[start:end])
|
||||
return out
|
||||
|
||||
def forward(self, query, key, value, **kwargs):
|
||||
self.check_redraw_projections()
|
||||
|
||||
b1, n1, _, h = *query.shape, self.heads
|
||||
b2, n2, _, h = *key.shape, self.heads
|
||||
|
||||
q = self.to_query(query)
|
||||
k = self.to_key(key)
|
||||
v = self.to_value(value)
|
||||
|
||||
q = q.reshape(b1, n1, h, -1).permute(0,2,1,3) # (b, h, n, d)
|
||||
k = k.reshape(b2, n2, h, -1).permute(0,2,1,3)
|
||||
v = v.reshape(b2, n2, h, -1).permute(0,2,1,3)
|
||||
|
||||
if b1*n1 > self.max_tokens or b2*n2 > self.max_tokens:
|
||||
out = self._batched_forward(q, k, v)
|
||||
else:
|
||||
out = self.fast_attention(q, k, v)
|
||||
|
||||
out = out.permute(0,2,1,3).reshape(b1,n1,-1)
|
||||
out = self.to_out(out)
|
||||
return self.dropout(out)
|
||||
|
317
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/predict_complex.py
Normal file
317
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/predict_complex.py
Normal file
|
@ -0,0 +1,317 @@
|
|||
import sys, os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import data
|
||||
from parsers import parse_a3m, read_templates
|
||||
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
|
||||
import util
|
||||
from collections import namedtuple
|
||||
from ffindex import *
|
||||
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
|
||||
from trFold import TRFold
|
||||
|
||||
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
|
||||
NBIN = [37, 37, 37, 19]
|
||||
|
||||
MODEL_PARAM ={
|
||||
"n_module" : 8,
|
||||
"n_module_str" : 4,
|
||||
"n_module_ref" : 4,
|
||||
"n_layer" : 1,
|
||||
"d_msa" : 384 ,
|
||||
"d_pair" : 288,
|
||||
"d_templ" : 64,
|
||||
"n_head_msa" : 12,
|
||||
"n_head_pair" : 8,
|
||||
"n_head_templ" : 4,
|
||||
"d_hidden" : 64,
|
||||
"r_ff" : 4,
|
||||
"n_resblock" : 1,
|
||||
"p_drop" : 0.1,
|
||||
"use_templ" : True,
|
||||
"performer_N_opts": {"nb_features": 64},
|
||||
"performer_L_opts": {"nb_features": 64}
|
||||
}
|
||||
|
||||
SE3_param = {
|
||||
"num_layers" : 2,
|
||||
"num_channels" : 16,
|
||||
"num_degrees" : 2,
|
||||
"l0_in_features": 32,
|
||||
"l0_out_features": 8,
|
||||
"l1_in_features": 3,
|
||||
"l1_out_features": 3,
|
||||
"num_edge_features": 32,
|
||||
"div": 2,
|
||||
"n_heads": 4
|
||||
}
|
||||
|
||||
REF_param = {
|
||||
"num_layers" : 3,
|
||||
"num_channels" : 32,
|
||||
"num_degrees" : 3,
|
||||
"l0_in_features": 32,
|
||||
"l0_out_features": 8,
|
||||
"l1_in_features": 3,
|
||||
"l1_out_features": 3,
|
||||
"num_edge_features": 32,
|
||||
"div": 4,
|
||||
"n_heads": 4
|
||||
}
|
||||
MODEL_PARAM['SE3_param'] = SE3_param
|
||||
MODEL_PARAM['REF_param'] = REF_param
|
||||
|
||||
# params for the folding protocol
|
||||
fold_params = {
|
||||
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
|
||||
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
|
||||
"DCUT" : 19.5,
|
||||
"ALPHA" : 1.57,
|
||||
|
||||
# TODO: add Cb to the motif
|
||||
"NCAC" : np.array([[-0.676, -1.294, 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
|
||||
"CLASH" : 2.0,
|
||||
"PCUT" : 0.5,
|
||||
"DSTEP" : 0.5,
|
||||
"ASTEP" : np.deg2rad(10.0),
|
||||
"XYZRAD" : 7.5,
|
||||
"WANG" : 0.1,
|
||||
"WCST" : 0.1
|
||||
}
|
||||
|
||||
fold_params["SG"] = fold_params["SG9"]
|
||||
|
||||
class Predictor():
|
||||
def __init__(self, model_dir=None, use_cpu=False):
|
||||
if model_dir == None:
|
||||
self.model_dir = "%s/weights"%(script_dir)
|
||||
else:
|
||||
self.model_dir = model_dir
|
||||
#
|
||||
# define model name
|
||||
self.model_name = "RoseTTAFold"
|
||||
if torch.cuda.is_available() and (not use_cpu):
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
self.active_fn = nn.Softmax(dim=1)
|
||||
|
||||
# define model & load model
|
||||
self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)
|
||||
could_load = self.load_model(self.model_name)
|
||||
if not could_load:
|
||||
print ("ERROR: failed to load model")
|
||||
sys.exit()
|
||||
|
||||
def load_model(self, model_name, suffix='e2e'):
|
||||
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
|
||||
if not os.path.exists(chk_fn):
|
||||
return False
|
||||
checkpoint = torch.load(chk_fn, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||
return True
|
||||
|
||||
def predict(self, a3m_fn, out_prefix, Ls, templ_npz=None, window=1000, shift=100):
|
||||
msa = parse_a3m(a3m_fn)
|
||||
N, L = msa.shape
|
||||
#
|
||||
if templ_npz != None:
|
||||
templ = np.load(templ_npz)
|
||||
xyz_t = torch.from_numpy(templ["xyz_t"])
|
||||
t1d = torch.from_numpy(templ["t1d"])
|
||||
t0d = torch.from_numpy(templ["t0d"])
|
||||
else:
|
||||
xyz_t = torch.full((1, L, 3, 3), np.nan).float()
|
||||
t1d = torch.zeros((1, L, 3)).float()
|
||||
t0d = torch.zeros((1,3)).float()
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
#
|
||||
msa = torch.tensor(msa).long().view(1, -1, L)
|
||||
idx_pdb_orig = torch.arange(L).long().view(1, L)
|
||||
idx_pdb = torch.arange(L).long().view(1, L)
|
||||
L_prev = 0
|
||||
for L_i in Ls[:-1]:
|
||||
idx_pdb[:,L_prev+L_i:] += 500 # it was 200 originally.
|
||||
L_prev += L_i
|
||||
seq = msa[:,0]
|
||||
#
|
||||
# template features
|
||||
xyz_t = xyz_t.float().unsqueeze(0)
|
||||
t1d = t1d.float().unsqueeze(0)
|
||||
t0d = t0d.float().unsqueeze(0)
|
||||
t2d = xyz_to_t2d(xyz_t, t0d)
|
||||
#
|
||||
# do cropped prediction
|
||||
if L > window*2:
|
||||
prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
|
||||
count_1d = np.zeros((L,), dtype=np.float32)
|
||||
count_2d = np.zeros((L,L), dtype=np.float32)
|
||||
node_s = np.zeros((L,MODEL_PARAM['d_msa']), dtype=np.float32)
|
||||
#
|
||||
grids = np.arange(0, L-window+shift, shift)
|
||||
ngrids = grids.shape[0]
|
||||
print("ngrid: ", ngrids)
|
||||
print("grids: ", grids)
|
||||
print("windows: ", window)
|
||||
|
||||
for i in range(ngrids):
|
||||
for j in range(i, ngrids):
|
||||
start_1 = grids[i]
|
||||
end_1 = min(grids[i]+window, L)
|
||||
start_2 = grids[j]
|
||||
end_2 = min(grids[j]+window, L)
|
||||
sel = np.zeros((L)).astype(np.bool)
|
||||
sel[start_1:end_1] = True
|
||||
sel[start_2:end_2] = True
|
||||
|
||||
input_msa = msa[:,:,sel]
|
||||
mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
|
||||
input_msa = input_msa[mask].unsqueeze(0)
|
||||
input_msa = input_msa[:,:1000].to(self.device)
|
||||
input_idx = idx_pdb[:,sel].to(self.device)
|
||||
input_idx_orig = idx_pdb_orig[:,sel]
|
||||
input_seq = input_msa[:,0].to(self.device)
|
||||
#
|
||||
# Select template
|
||||
input_t1d = t1d[:,:,sel].to(self.device) # (B, T, L, 3)
|
||||
input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
|
||||
#
|
||||
print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
|
||||
with torch.cuda.amp.autocast():
|
||||
logit_s, node, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d, return_raw=True)
|
||||
#
|
||||
# Not sure How can we merge init_crds.....
|
||||
pred_lddt = torch.clamp(pred_lddt, 0.0, 1.0)
|
||||
sub_idx = input_idx_orig[0]
|
||||
sub_idx_2d = np.ix_(sub_idx, sub_idx)
|
||||
count_2d[sub_idx_2d] += 1.0
|
||||
count_1d[sub_idx] += 1.0
|
||||
node_s[sub_idx] += node[0].cpu().numpy()
|
||||
for i_logit, logit in enumerate(logit_s):
|
||||
prob = self.active_fn(logit.float()) # calculate distogram
|
||||
prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
|
||||
prob_s[i_logit][sub_idx_2d] += prob
|
||||
del logit_s, node
|
||||
#
|
||||
for i in range(4):
|
||||
prob_s[i] = prob_s[i] / count_2d[:,:,None]
|
||||
prob_in = np.concatenate(prob_s, axis=-1)
|
||||
node_s = node_s / count_1d[:, None]
|
||||
#
|
||||
node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
|
||||
seq = msa[:,0].to(self.device)
|
||||
idx_pdb = idx_pdb.to(self.device)
|
||||
prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
|
||||
with torch.cuda.amp.autocast():
|
||||
xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
|
||||
print (lddt.mean())
|
||||
else:
|
||||
msa = msa[:,:1000].to(self.device)
|
||||
seq = msa[:,0]
|
||||
idx_pdb = idx_pdb.to(self.device)
|
||||
t1d = t1d[:,:10].to(self.device)
|
||||
t2d = t2d[:,:10].to(self.device)
|
||||
with torch.cuda.amp.autocast():
|
||||
logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
|
||||
print (lddt.mean())
|
||||
prob_s = list()
|
||||
for logit in logit_s:
|
||||
prob = self.active_fn(logit.float()) # distogram
|
||||
prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
|
||||
prob_s.append(prob)
|
||||
|
||||
np.savez_compressed("%s.npz"%out_prefix, dist=prob_s[0].astype(np.float16), \
|
||||
omega=prob_s[1].astype(np.float16),\
|
||||
theta=prob_s[2].astype(np.float16),\
|
||||
phi=prob_s[3].astype(np.float16))
|
||||
|
||||
# run TRFold
|
||||
prob_trF = list()
|
||||
for prob in prob_s:
|
||||
prob = torch.tensor(prob).permute(2,0,1).to(self.device)
|
||||
prob += 1e-8
|
||||
prob = prob / torch.sum(prob, dim=0)[None]
|
||||
prob_trF.append(prob)
|
||||
xyz = xyz[0, :, 1]
|
||||
TRF = TRFold(prob_trF, fold_params)
|
||||
xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
|
||||
print (xyz.shape, lddt[0].shape, seq[0].shape)
|
||||
self.write_pdb(seq[0], xyz, Ls, Bfacts=lddt[0], prefix=out_prefix)
|
||||
|
||||
def write_pdb(self, seq, atoms, Ls, Bfacts=None, prefix=None):
|
||||
chainIDs = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
L = len(seq)
|
||||
filename = "%s.pdb"%prefix
|
||||
ctr = 1
|
||||
with open(filename, 'wt') as f:
|
||||
if Bfacts == None:
|
||||
Bfacts = np.zeros(L)
|
||||
else:
|
||||
Bfacts = torch.clamp( Bfacts, 0, 1)
|
||||
|
||||
for i,s in enumerate(seq):
|
||||
if (len(atoms.shape)==2):
|
||||
resNo = i+1
|
||||
chain = "A"
|
||||
for i_chain in range(len(Ls)-1,0,-1):
|
||||
tot_res = sum(Ls[:i_chain])
|
||||
if i+1 > tot_res:
|
||||
chain = chainIDs[i_chain]
|
||||
resNo = i+1 - tot_res
|
||||
break
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, " CA ", util.num2aa[s],
|
||||
chain, resNo, atoms[i,0], atoms[i,1], atoms[i,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
||||
|
||||
elif atoms.shape[1]==3:
|
||||
resNo = i+1
|
||||
chain = "A"
|
||||
for i_chain in range(len(Ls)-1,0,-1):
|
||||
tot_res = sum(Ls[:i_chain])
|
||||
if i+1 > tot_res:
|
||||
chain = chainIDs[i_chain]
|
||||
resNo = i+1 - tot_res
|
||||
break
|
||||
for j,atm_j in enumerate((" N "," CA "," C ")):
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, atm_j, util.num2aa[s],
|
||||
chain, resNo, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
||||
|
||||
def get_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument("-m", dest="model_dir", default="%s/weights"%(script_dir),
|
||||
help="Path to pre-trained network weights [%s/weights]"%script_dir)
|
||||
parser.add_argument("-i", dest="a3m_fn", required=True,
|
||||
help="Input multiple sequence alignments (in a3m format)")
|
||||
parser.add_argument("-o", dest="out_prefix", required=True,
|
||||
help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb")
|
||||
parser.add_argument("-Ls", dest="Ls", required=True, nargs="+", type=int,
|
||||
help="The length of the each subunit (e.g. 220 400)")
|
||||
parser.add_argument("--templ_npz", default=None,
|
||||
help='''npz file containing complex template information (xyz_t, t1d, t0d). If not provided, zero matrices will be given as templates
|
||||
- xyz_t: N, CA, C coordinates of complex templates (T, L, 3, 3) For the unaligned region, it should be NaN
|
||||
- t1d: 1-D features from HHsearch results (score, SS, probab column from atab file) (T, L, 3). For the unaligned region, it should be zeros
|
||||
- t0d: 0-D features from HHsearch (Probability/100.0, Ideintities/100.0, Similarity fro hhr file) (T, 3)''')
|
||||
parser.add_argument("--cpu", dest='use_cpu', default=False, action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
if not os.path.exists("%s.npz"%args.out_prefix):
|
||||
pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
|
||||
pred.predict(args.a3m_fn, args.out_prefix, args.Ls, templ_npz=args.templ_npz)
|
324
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/predict_e2e.py
Normal file
324
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/predict_e2e.py
Normal file
|
@ -0,0 +1,324 @@
|
|||
import sys, os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import data
|
||||
from parsers import parse_a3m, read_templates
|
||||
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
|
||||
import util
|
||||
from collections import namedtuple
|
||||
from ffindex import *
|
||||
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
|
||||
from trFold import TRFold
|
||||
|
||||
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
|
||||
|
||||
NBIN = [37, 37, 37, 19]
|
||||
|
||||
MODEL_PARAM ={
|
||||
"n_module" : 8,
|
||||
"n_module_str" : 4,
|
||||
"n_module_ref" : 4,
|
||||
"n_layer" : 1,
|
||||
"d_msa" : 384 ,
|
||||
"d_pair" : 288,
|
||||
"d_templ" : 64,
|
||||
"n_head_msa" : 12,
|
||||
"n_head_pair" : 8,
|
||||
"n_head_templ" : 4,
|
||||
"d_hidden" : 64,
|
||||
"r_ff" : 4,
|
||||
"n_resblock" : 1,
|
||||
"p_drop" : 0.0,
|
||||
"use_templ" : True,
|
||||
"performer_N_opts": {"nb_features": 64},
|
||||
"performer_L_opts": {"nb_features": 64}
|
||||
}
|
||||
|
||||
SE3_param = {
|
||||
"num_layers" : 2,
|
||||
"num_channels" : 16,
|
||||
"num_degrees" : 2,
|
||||
"l0_in_features": 32,
|
||||
"l0_out_features": 8,
|
||||
"l1_in_features": 3,
|
||||
"l1_out_features": 3,
|
||||
"num_edge_features": 32,
|
||||
"div": 2,
|
||||
"n_heads": 4
|
||||
}
|
||||
|
||||
REF_param = {
|
||||
"num_layers" : 3,
|
||||
"num_channels" : 32,
|
||||
"num_degrees" : 3,
|
||||
"l0_in_features": 32,
|
||||
"l0_out_features": 8,
|
||||
"l1_in_features": 3,
|
||||
"l1_out_features": 3,
|
||||
"num_edge_features": 32,
|
||||
"div": 4,
|
||||
"n_heads": 4
|
||||
}
|
||||
MODEL_PARAM['SE3_param'] = SE3_param
|
||||
MODEL_PARAM['REF_param'] = REF_param
|
||||
|
||||
# params for the folding protocol
|
||||
fold_params = {
|
||||
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
|
||||
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
|
||||
"DCUT" : 19.5,
|
||||
"ALPHA" : 1.57,
|
||||
|
||||
# TODO: add Cb to the motif
|
||||
"NCAC" : np.array([[-0.676, -1.294, 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
|
||||
"CLASH" : 2.0,
|
||||
"PCUT" : 0.5,
|
||||
"DSTEP" : 0.5,
|
||||
"ASTEP" : np.deg2rad(10.0),
|
||||
"XYZRAD" : 7.5,
|
||||
"WANG" : 0.1,
|
||||
"WCST" : 0.1
|
||||
}
|
||||
|
||||
fold_params["SG"] = fold_params["SG9"]
|
||||
|
||||
class Predictor():
|
||||
def __init__(self, model_dir=None, use_cpu=False):
|
||||
if model_dir == None:
|
||||
self.model_dir = "%s/models"%(os.path.dirname(os.path.realpath(__file__)))
|
||||
else:
|
||||
self.model_dir = model_dir
|
||||
#
|
||||
# define model name
|
||||
self.model_name = "RoseTTAFold"
|
||||
if torch.cuda.is_available() and (not use_cpu):
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
self.active_fn = nn.Softmax(dim=1)
|
||||
|
||||
# define model & load model
|
||||
self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)
|
||||
|
||||
def load_model(self, model_name, suffix='e2e'):
|
||||
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
|
||||
if not os.path.exists(chk_fn):
|
||||
return False
|
||||
checkpoint = torch.load(chk_fn, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||
return True
|
||||
|
||||
def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=150, shift=75):
|
||||
msa = parse_a3m(a3m_fn)
|
||||
N, L = msa.shape
|
||||
#
|
||||
if hhr_fn != None:
|
||||
xyz_t, t1d, t0d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=10)
|
||||
else:
|
||||
xyz_t = torch.full((1, L, 3, 3), np.nan).float()
|
||||
t1d = torch.zeros((1, L, 3)).float()
|
||||
t0d = torch.zeros((1,3)).float()
|
||||
#
|
||||
msa = torch.tensor(msa).long().view(1, -1, L)
|
||||
idx_pdb = torch.arange(L).long().view(1, L)
|
||||
seq = msa[:,0]
|
||||
#
|
||||
# template features
|
||||
xyz_t = xyz_t.float().unsqueeze(0)
|
||||
t1d = t1d.float().unsqueeze(0)
|
||||
t0d = t0d.float().unsqueeze(0)
|
||||
t2d = xyz_to_t2d(xyz_t, t0d)
|
||||
|
||||
could_load = self.load_model(self.model_name, suffix="e2e")
|
||||
if not could_load:
|
||||
print ("ERROR: failed to load model")
|
||||
sys.exit()
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# do cropped prediction if protein is too big
|
||||
if L > window*2:
|
||||
prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
|
||||
count_1d = np.zeros((L,), dtype=np.float32)
|
||||
count_2d = np.zeros((L,L), dtype=np.float32)
|
||||
node_s = np.zeros((L,MODEL_PARAM['d_msa']), dtype=np.float32)
|
||||
#
|
||||
grids = np.arange(0, L-window+shift, shift)
|
||||
ngrids = grids.shape[0]
|
||||
print("ngrid: ", ngrids)
|
||||
print("grids: ", grids)
|
||||
print("windows: ", window)
|
||||
|
||||
for i in range(ngrids):
|
||||
for j in range(i, ngrids):
|
||||
start_1 = grids[i]
|
||||
end_1 = min(grids[i]+window, L)
|
||||
start_2 = grids[j]
|
||||
end_2 = min(grids[j]+window, L)
|
||||
sel = np.zeros((L)).astype(np.bool)
|
||||
sel[start_1:end_1] = True
|
||||
sel[start_2:end_2] = True
|
||||
|
||||
input_msa = msa[:,:,sel]
|
||||
mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
|
||||
input_msa = input_msa[mask].unsqueeze(0)
|
||||
input_msa = input_msa[:,:1000].to(self.device)
|
||||
input_idx = idx_pdb[:,sel].to(self.device)
|
||||
input_seq = input_msa[:,0].to(self.device)
|
||||
#
|
||||
# Select template
|
||||
input_t1d = t1d[:,:,sel].to(self.device) # (B, T, L, 3)
|
||||
input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
|
||||
#
|
||||
print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
|
||||
with torch.cuda.amp.autocast():
|
||||
logit_s, node, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d, return_raw=True)
|
||||
#
|
||||
# Not sure How can we merge init_crds.....
|
||||
sub_idx = input_idx[0].cpu()
|
||||
sub_idx_2d = np.ix_(sub_idx, sub_idx)
|
||||
count_2d[sub_idx_2d] += 1.0
|
||||
count_1d[sub_idx] += 1.0
|
||||
node_s[sub_idx] += node[0].cpu().numpy()
|
||||
for i_logit, logit in enumerate(logit_s):
|
||||
prob = self.active_fn(logit.float()) # calculate distogram
|
||||
prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
|
||||
prob_s[i_logit][sub_idx_2d] += prob
|
||||
del logit_s, node
|
||||
#
|
||||
# combine all crops
|
||||
for i in range(4):
|
||||
prob_s[i] = prob_s[i] / count_2d[:,:,None]
|
||||
prob_in = np.concatenate(prob_s, axis=-1)
|
||||
node_s = node_s / count_1d[:, None]
|
||||
#
|
||||
# Do iterative refinement using SE(3)-Transformers
|
||||
# clear cache memory
|
||||
torch.cuda.empty_cache()
|
||||
#
|
||||
node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
|
||||
seq = msa[:,0].to(self.device)
|
||||
idx_pdb = idx_pdb.to(self.device)
|
||||
prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
|
||||
with torch.cuda.amp.autocast():
|
||||
xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
|
||||
else:
|
||||
msa = msa[:,:1000].to(self.device)
|
||||
seq = msa[:,0]
|
||||
idx_pdb = idx_pdb.to(self.device)
|
||||
t1d = t1d[:,:10].to(self.device)
|
||||
t2d = t2d[:,:10].to(self.device)
|
||||
with torch.cuda.amp.autocast():
|
||||
logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
|
||||
prob_s = list()
|
||||
for logit in logit_s:
|
||||
prob = self.active_fn(logit.float()) # distogram
|
||||
prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
|
||||
prob_s.append(prob)
|
||||
|
||||
np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
|
||||
omega=prob_s[1].astype(np.float16),\
|
||||
theta=prob_s[2].astype(np.float16),\
|
||||
phi=prob_s[3].astype(np.float16))
|
||||
|
||||
self.write_pdb(seq[0], xyz[0], idx_pdb[0], Bfacts=lddt[0], prefix="%s_init"%(out_prefix))
|
||||
|
||||
# run TRFold
|
||||
prob_trF = list()
|
||||
for prob in prob_s:
|
||||
prob = torch.tensor(prob).permute(2,0,1).to(self.device)
|
||||
prob += 1e-8
|
||||
prob = prob / torch.sum(prob, dim=0)[None]
|
||||
prob_trF.append(prob)
|
||||
xyz = xyz[0, :, 1]
|
||||
TRF = TRFold(prob_trF, fold_params)
|
||||
xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
|
||||
xyz = xyz.detach().cpu().numpy()
|
||||
# add O and Cb
|
||||
N = xyz[:,0,:]
|
||||
CA = xyz[:,1,:]
|
||||
C = xyz[:,2,:]
|
||||
O = self.extend(np.roll(N, -1, axis=0), CA, C, 1.231, 2.108, -3.142)
|
||||
xyz = np.concatenate((xyz, O[:,None,:]), axis=1)
|
||||
self.write_pdb(seq[0], xyz, idx_pdb[0], Bfacts=lddt[0], prefix=out_prefix)
|
||||
|
||||
def extend(self, a,b,c, L,A,D):
|
||||
'''
|
||||
input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
|
||||
output: 4th coord
|
||||
'''
|
||||
N = lambda x: x/np.sqrt(np.square(x).sum(-1,keepdims=True) + 1e-8)
|
||||
bc = N(b-c)
|
||||
n = N(np.cross(b-a, bc))
|
||||
m = [bc,np.cross(n,bc),n]
|
||||
d = [L*np.cos(A), L*np.sin(A)*np.cos(D), -L*np.sin(A)*np.sin(D)]
|
||||
return c + sum([m*d for m,d in zip(m,d)])
|
||||
|
||||
def write_pdb(self, seq, atoms, idx, Bfacts=None, prefix=None):
|
||||
L = len(seq)
|
||||
filename = "%s.pdb"%prefix
|
||||
ctr = 1
|
||||
with open(filename, 'wt') as f:
|
||||
if Bfacts == None:
|
||||
Bfacts = np.zeros(L)
|
||||
else:
|
||||
Bfacts = torch.clamp( Bfacts, 0, 1)
|
||||
|
||||
for i,s in enumerate(seq):
|
||||
if (len(atoms.shape)==2):
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, " CA ", util.num2aa[s],
|
||||
"A", idx[i]+1, atoms[i,0], atoms[i,1], atoms[i,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
||||
|
||||
elif atoms.shape[1]==3:
|
||||
for j,atm_j in enumerate((" N "," CA "," C ")):
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, atm_j, util.num2aa[s],
|
||||
"A", idx[i]+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
||||
|
||||
elif atoms.shape[1]==4:
|
||||
for j,atm_j in enumerate((" N "," CA "," C ", " O ")):
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, atm_j, util.num2aa[s],
|
||||
"A", idx[i]+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
||||
|
||||
|
||||
def get_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", dest="model_dir", default="%s/weights"%(script_dir),
|
||||
help="Path to pre-trained network weights [%s/weights]"%script_dir)
|
||||
parser.add_argument("-i", dest="a3m_fn", required=True,
|
||||
help="Input multiple sequence alignments (in a3m format)")
|
||||
parser.add_argument("-o", dest="out_prefix", required=True,
|
||||
help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb")
|
||||
parser.add_argument("--hhr", default=None,
|
||||
help="HHsearch output file (hhr file). If not provided, zero matrices will be given as templates")
|
||||
parser.add_argument("--atab", default=None,
|
||||
help="HHsearch output file (atab file)")
|
||||
parser.add_argument("--db", default="%s/pdb100_2021Mar03/pdb100_2021Mar03"%script_dir,
|
||||
help="Path to template database [%s/pdb100_2021Mar03]"%script_dir)
|
||||
parser.add_argument("--cpu", dest='use_cpu', default=False, action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
FFDB=args.db
|
||||
FFindexDB = namedtuple("FFindexDB", "index, data")
|
||||
ffdb = FFindexDB(read_index(FFDB+'_pdb.ffindex'),
|
||||
read_data(FFDB+'_pdb.ffdata'))
|
||||
|
||||
if not os.path.exists("%s.npz"%args.out_prefix):
|
||||
pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
|
||||
pred.predict(args.a3m_fn, args.out_prefix, args.hhr, args.atab)
|
|
@ -0,0 +1,200 @@
|
|||
import sys, os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import data
|
||||
from parsers import parse_a3m, read_templates
|
||||
from RoseTTAFoldModel import RoseTTAFoldModule
|
||||
import util
|
||||
from collections import namedtuple
|
||||
from ffindex import *
|
||||
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
|
||||
|
||||
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
|
||||
|
||||
NBIN = [37, 37, 37, 19]
|
||||
|
||||
MODEL_PARAM ={
|
||||
"n_module" : 8,
|
||||
"n_module_str" : 4,
|
||||
"n_layer" : 1,
|
||||
"d_msa" : 384 ,
|
||||
"d_pair" : 288,
|
||||
"d_templ" : 64,
|
||||
"n_head_msa" : 12,
|
||||
"n_head_pair" : 8,
|
||||
"n_head_templ" : 4,
|
||||
"d_hidden" : 64,
|
||||
"r_ff" : 4,
|
||||
"n_resblock" : 1,
|
||||
"p_drop" : 0.1,
|
||||
"use_templ" : True,
|
||||
"performer_N_opts": {"nb_features": 64},
|
||||
"performer_L_opts": {"nb_features": 64}
|
||||
}
|
||||
|
||||
SE3_param = {
|
||||
"num_layers" : 2,
|
||||
"num_channels" : 16,
|
||||
"num_degrees" : 2,
|
||||
"l0_in_features": 32,
|
||||
"l0_out_features": 8,
|
||||
"l1_in_features": 3,
|
||||
"l1_out_features": 3,
|
||||
"num_edge_features": 32,
|
||||
"div": 2,
|
||||
"n_heads": 4
|
||||
}
|
||||
MODEL_PARAM['SE3_param'] = SE3_param
|
||||
|
||||
class Predictor():
|
||||
def __init__(self, model_dir=None, use_cpu=False):
|
||||
if model_dir == None:
|
||||
self.model_dir = "%s/models"%(os.path.dirname(os.path.realpath(__file__)))
|
||||
else:
|
||||
self.model_dir = model_dir
|
||||
#
|
||||
# define model name
|
||||
self.model_name = "RoseTTAFold"
|
||||
if torch.cuda.is_available() and (not use_cpu):
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
self.active_fn = nn.Softmax(dim=1)
|
||||
|
||||
# define model & load model
|
||||
self.model = RoseTTAFoldModule(**MODEL_PARAM).to(self.device)
|
||||
could_load = self.load_model(self.model_name)
|
||||
if not could_load:
|
||||
print ("ERROR: failed to load model")
|
||||
sys.exit()
|
||||
|
||||
def load_model(self, model_name, suffix='pyrosetta'):
|
||||
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
|
||||
if not os.path.exists(chk_fn):
|
||||
return False
|
||||
checkpoint = torch.load(chk_fn, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
||||
return True
|
||||
|
||||
def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=150, shift=50):
|
||||
msa = parse_a3m(a3m_fn)
|
||||
N, L = msa.shape
|
||||
#
|
||||
if hhr_fn != None:
|
||||
xyz_t, t1d, t0d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=25)
|
||||
else:
|
||||
xyz_t = torch.full((1, L, 3, 3), np.nan).float()
|
||||
t1d = torch.zeros((1, L, 3)).float()
|
||||
t0d = torch.zeros((1,3)).float()
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
#
|
||||
msa = torch.tensor(msa).long().view(1, -1, L)
|
||||
idx_pdb = torch.arange(L).long().view(1, L)
|
||||
seq = msa[:,0]
|
||||
#
|
||||
# template features
|
||||
xyz_t = xyz_t.float().unsqueeze(0)
|
||||
t1d = t1d.float().unsqueeze(0)
|
||||
t0d = t0d.float().unsqueeze(0)
|
||||
t2d = xyz_to_t2d(xyz_t, t0d)
|
||||
#
|
||||
# do cropped prediction
|
||||
if L > window*2:
|
||||
prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
|
||||
count = np.zeros((L,L), dtype=np.float32)
|
||||
#
|
||||
grids = np.arange(0, L-window+shift, shift)
|
||||
ngrids = grids.shape[0]
|
||||
print("ngrid: ", ngrids)
|
||||
print("grids: ", grids)
|
||||
print("windows: ", window)
|
||||
|
||||
for i in range(ngrids):
|
||||
for j in range(i, ngrids):
|
||||
start_1 = grids[i]
|
||||
end_1 = min(grids[i]+window, L)
|
||||
start_2 = grids[j]
|
||||
end_2 = min(grids[j]+window, L)
|
||||
sel = np.zeros((L)).astype(np.bool)
|
||||
sel[start_1:end_1] = True
|
||||
sel[start_2:end_2] = True
|
||||
|
||||
input_msa = msa[:,:,sel]
|
||||
mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
|
||||
input_msa = input_msa[mask].unsqueeze(0)
|
||||
input_msa = input_msa[:,:1000].to(self.device)
|
||||
input_idx = idx_pdb[:,sel].to(self.device)
|
||||
input_seq = input_msa[:,0].to(self.device)
|
||||
#
|
||||
input_t1d = t1d[:,:,sel].to(self.device)
|
||||
input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
|
||||
#
|
||||
print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
|
||||
with torch.cuda.amp.autocast():
|
||||
logit_s, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d)
|
||||
#
|
||||
pred_lddt = torch.clamp(pred_lddt, 0.0, 1.0)
|
||||
weight = pred_lddt[0][:,None] + pred_lddt[0][None,:]
|
||||
weight = weight.cpu().numpy() + 1e-8
|
||||
sub_idx = input_idx[0].cpu()
|
||||
sub_idx_2d = np.ix_(sub_idx, sub_idx)
|
||||
count[sub_idx_2d] += weight
|
||||
for i_logit, logit in enumerate(logit_s):
|
||||
prob = self.active_fn(logit.float()) # calculate distogram
|
||||
prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
|
||||
prob_s[i_logit][sub_idx_2d] += weight[:,:,None]*prob
|
||||
for i in range(4):
|
||||
prob_s[i] = prob_s[i] / count[:,:,None]
|
||||
else:
|
||||
msa = msa[:,:1000].to(self.device)
|
||||
seq = msa[:,0]
|
||||
idx_pdb = idx_pdb.to(self.device)
|
||||
t1d = t1d.to(self.device)
|
||||
t2d = t2d.to(self.device)
|
||||
logit_s, init_crds, pred_lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
|
||||
prob_s = list()
|
||||
for logit in logit_s:
|
||||
prob = self.active_fn(logit.float()) # distogram
|
||||
prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
|
||||
prob_s.append(prob)
|
||||
|
||||
np.savez_compressed("%s.npz"%out_prefix, dist=prob_s[0].astype(np.float16), \
|
||||
omega=prob_s[1].astype(np.float16),\
|
||||
theta=prob_s[2].astype(np.float16),\
|
||||
phi=prob_s[3].astype(np.float16))
|
||||
|
||||
|
||||
def get_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", dest="model_dir", default="%s/weights"%(script_dir),
|
||||
help="Path to pre-trained network weights [%s/weights]"%script_dir)
|
||||
parser.add_argument("-i", dest="a3m_fn", required=True,
|
||||
help="Input multiple sequence alignments (in a3m format)")
|
||||
parser.add_argument("-o", dest="out_prefix", required=True,
|
||||
help="Prefix for output file. The output file will be [out_prefix].npz")
|
||||
parser.add_argument("--hhr", default=None,
|
||||
help="HHsearch output file (hhr file). If not provided, zero matrices will be given as templates")
|
||||
parser.add_argument("--atab", default=None,
|
||||
help="HHsearch output file (atab file)")
|
||||
parser.add_argument("--db", default="%s/pdb100_2021Mar03/pdb100_2021Mar03"%script_dir,
|
||||
help="Path to template database [%s/pdb100_2021Mar03]"%script_dir)
|
||||
parser.add_argument("--cpu", dest='use_cpu', default=False, action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
FFDB=args.db
|
||||
FFindexDB = namedtuple("FFindexDB", "index, data")
|
||||
ffdb = FFindexDB(read_index(FFDB+'_pdb.ffindex'),
|
||||
read_data(FFDB+'_pdb.ffdata'))
|
||||
|
||||
if not os.path.exists("%s.npz"%args.out_prefix):
|
||||
pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
|
||||
pred.predict(args.a3m_fn, args.out_prefix, args.hhr, args.atab)
|
95
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/resnet.py
Normal file
95
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/resnet.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# original resblock
|
||||
class ResBlock2D(nn.Module):
|
||||
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
|
||||
super(ResBlock2D, self).__init__()
|
||||
padding = self._get_same_padding(kernel, dilation)
|
||||
|
||||
layer_s = list()
|
||||
layer_s.append(nn.Conv2d(n_c, n_c, kernel, padding=padding, dilation=dilation, bias=False))
|
||||
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
|
||||
layer_s.append(nn.ELU(inplace=True))
|
||||
# dropout
|
||||
layer_s.append(nn.Dropout(p_drop))
|
||||
# convolution
|
||||
layer_s.append(nn.Conv2d(n_c, n_c, kernel, dilation=dilation, padding=padding, bias=False))
|
||||
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
|
||||
self.layer = nn.Sequential(*layer_s)
|
||||
self.final_activation = nn.ELU(inplace=True)
|
||||
|
||||
def _get_same_padding(self, kernel, dilation):
|
||||
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return self.final_activation(x + out)
|
||||
|
||||
# pre-activation bottleneck resblock
|
||||
class ResBlock2D_bottleneck(nn.Module):
|
||||
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
|
||||
super(ResBlock2D_bottleneck, self).__init__()
|
||||
padding = self._get_same_padding(kernel, dilation)
|
||||
|
||||
n_b = n_c // 2 # bottleneck channel
|
||||
|
||||
layer_s = list()
|
||||
# pre-activation
|
||||
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
|
||||
layer_s.append(nn.ELU(inplace=True))
|
||||
# project down to n_b
|
||||
layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
|
||||
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
|
||||
layer_s.append(nn.ELU(inplace=True))
|
||||
# convolution
|
||||
layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
|
||||
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
|
||||
layer_s.append(nn.ELU(inplace=True))
|
||||
# dropout
|
||||
layer_s.append(nn.Dropout(p_drop))
|
||||
# project up
|
||||
layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
|
||||
|
||||
self.layer = nn.Sequential(*layer_s)
|
||||
|
||||
def _get_same_padding(self, kernel, dilation):
|
||||
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return x + out
|
||||
|
||||
class ResidualNetwork(nn.Module):
|
||||
def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
|
||||
dilation=[1,2,4,8], block_type='orig', p_drop=0.15):
|
||||
super(ResidualNetwork, self).__init__()
|
||||
|
||||
|
||||
layer_s = list()
|
||||
# project to n_feat_block
|
||||
if n_feat_in != n_feat_block:
|
||||
layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
|
||||
if block_type =='orig': # should acitivate input
|
||||
layer_s.append(nn.InstanceNorm2d(n_feat_block, affine=True, eps=1e-6))
|
||||
layer_s.append(nn.ELU(inplace=True))
|
||||
|
||||
# add resblocks
|
||||
for i_block in range(n_block):
|
||||
d = dilation[i_block%len(dilation)]
|
||||
if block_type == 'orig':
|
||||
res_block = ResBlock2D(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
|
||||
else:
|
||||
res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
|
||||
layer_s.append(res_block)
|
||||
|
||||
if n_feat_out != n_feat_block:
|
||||
# project to n_feat_out
|
||||
layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
|
||||
|
||||
self.layer = nn.Sequential(*layer_s)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.layer(x)
|
||||
return output
|
||||
|
252
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/trFold.py
Normal file
252
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/trFold.py
Normal file
|
@ -0,0 +1,252 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def perturb_init(xyz, batch, noise=0.5):
|
||||
L = xyz.shape[0]
|
||||
pert = torch.tensor(np.random.uniform(noise, size=(batch, L, 3)), device=xyz.device)
|
||||
|
||||
xyz = xyz.unsqueeze(0) + pert.detach()
|
||||
return xyz
|
||||
|
||||
def Q2R(Q):
|
||||
'''convert quaternions to rotation matrices'''
|
||||
b,l,_ = Q.shape
|
||||
w,x,y,z = Q[...,0],Q[...,1],Q[...,2],Q[...,3]
|
||||
xx,xy,xz,xw = x*x, x*y, x*z, x*w
|
||||
yy,yz,yw = y*y, y*z, y*w
|
||||
zz,zw = z*z, z*w
|
||||
R = torch.stack([1-2*yy-2*zz, 2*xy-2*zw, 2*xz+2*yw,
|
||||
2*xy+2*zw, 1-2*xx-2*zz, 2*yz-2*xw,
|
||||
2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy],dim=-1).view(b,l,3,3)
|
||||
return R
|
||||
|
||||
def get_cb(N,Ca,C):
|
||||
"""recreate Cb given N,Ca,C"""
|
||||
b = Ca - N
|
||||
c = C - Ca
|
||||
a = torch.cross(b, c, dim=-1)
|
||||
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||
return Cb
|
||||
|
||||
# ============================================================
|
||||
def get_ang(a, b, c):
|
||||
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
|
||||
from Cartesian coordinates of three sets of atoms a,b,c
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b,c : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of three sets of atoms
|
||||
Returns
|
||||
-------
|
||||
ang : pytorch tensor of shape [batch,nres]
|
||||
stores resulting planar angles
|
||||
"""
|
||||
v = a - b
|
||||
w = c - b
|
||||
v = v / torch.norm(v, dim=-1, keepdim=True)
|
||||
w = w / torch.norm(w, dim=-1, keepdim=True)
|
||||
vw = torch.sum(v*w, dim=-1)
|
||||
|
||||
return torch.acos(vw)
|
||||
|
||||
# ============================================================
|
||||
def get_dih(a, b, c, d):
|
||||
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
|
||||
given Cartesian coordinates of four sets of atoms a,b,c,d
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a,b,c,d : pytorch tensors of shape [batch,nres,3]
|
||||
store Cartesian coordinates of four sets of atoms
|
||||
Returns
|
||||
-------
|
||||
dih : pytorch tensor of shape [batch,nres]
|
||||
stores resulting dihedrals
|
||||
"""
|
||||
b0 = a - b
|
||||
b1 = c - b
|
||||
b2 = d - c
|
||||
|
||||
b1 = b1 / torch.norm(b1, dim=-1, keepdim=True)
|
||||
|
||||
v = b0 - torch.sum(b0*b1, dim=-1, keepdim=True)*b1
|
||||
w = b2 - torch.sum(b2*b1, dim=-1, keepdim=True)*b1
|
||||
|
||||
x = torch.sum(v*w, dim=-1)
|
||||
y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1)
|
||||
|
||||
return torch.atan2(y, x)
|
||||
|
||||
|
||||
class TRFold():
|
||||
|
||||
def __init__(self, pred, params):
|
||||
|
||||
self.pred = pred
|
||||
self.params = params
|
||||
self.device = self.pred[0].device
|
||||
|
||||
# dfire background correction for distograms
|
||||
self.bkgd = (torch.linspace(4.25,19.75,32,device=self.device)/
|
||||
self.params['DCUT'])**self.params['ALPHA']
|
||||
|
||||
# background correction for phi
|
||||
ang = torch.linspace(0,np.pi,19,device=self.device)[:-1]
|
||||
self.bkgp = 0.5*(torch.cos(ang)-torch.cos(ang+np.deg2rad(10.0)))
|
||||
|
||||
# Sav-Gol filter
|
||||
self.sg = torch.from_numpy(self.params['SG']).float().to(self.device)
|
||||
|
||||
# paddings for distograms:
|
||||
# left - linear clash; right - zeroes
|
||||
padRsize = self.sg.shape[-1]//2+3
|
||||
padLsize = padRsize + 8
|
||||
padR = torch.zeros(padRsize,device=self.device)
|
||||
padL = torch.arange(1,padLsize+1,device=self.device).flip(0)*self.params['CLASH']
|
||||
self.padR = padR[:,None]
|
||||
self.padL = padL[:,None]
|
||||
|
||||
# backbone motif
|
||||
self.ncac = torch.from_numpy(self.params['NCAC']).to(self.device)
|
||||
|
||||
def akima(self, y,h):
|
||||
''' Akima spline coefficients (boundaries trimmed to [2:-2])
|
||||
https://doi.org/10.1145/321607.321609 '''
|
||||
m = (y[:,1:]-y[:,:-1])/h
|
||||
#m += 1e-3*torch.randn(m.shape, device=m.device)
|
||||
m4m3 = torch.abs(m[:,3:]-m[:,2:-1])
|
||||
m2m1 = torch.abs(m[:,1:-2]-m[:,:-3])
|
||||
t = (m4m3*m[:,1:-2] + m2m1*m[:,2:-1])/(m4m3+m2m1)
|
||||
t[torch.isnan(t)] = 0.0
|
||||
dy = y[:,3:-2]-y[:,2:-3]
|
||||
coef = torch.stack([y[:,2:-3],
|
||||
t[:,:-1],
|
||||
(3*dy/h - 2*t[:,:-1] - t[:,1:])/h,
|
||||
(t[:,:-1]+t[:,1:] - 2*dy/h)/h**2], dim=-1)
|
||||
return coef
|
||||
|
||||
def fold(self, xyz, batch=32, lr=0.8, nsteps=100):
|
||||
|
||||
pd,po,pt,pp = self.pred
|
||||
L = pd.shape[-1]
|
||||
|
||||
p20 = (6.0-pd[-1]-po[-1]-pt[-1]-pp[-1]-(pt[-1]+pp[-1]).T)/6
|
||||
i,j = torch.triu_indices(L,L,1,device=self.device)
|
||||
sel = torch.where(p20[i,j]>self.params['PCUT'])[0]
|
||||
|
||||
# indices for dist and omega (symmetric)
|
||||
i_s,j_s = i[sel], j[sel]
|
||||
|
||||
# indices for theta and phi (asymmetric)
|
||||
i_a,j_a = torch.hstack([i_s,j_s]), torch.hstack([j_s,i_s])
|
||||
|
||||
# background-corrected initial restraints
|
||||
cstd = -torch.log(pd[4:36,i_s,j_s]/self.bkgd[:,None])
|
||||
csto = -torch.log(po[0:36,i_s,j_s]/(1./36)) # omega and theta
|
||||
cstt = -torch.log(pt[0:36,i_a,j_a]/(1./36)) # are almost uniform
|
||||
cstp = -torch.log(pp[0:18,i_a,j_a]/self.bkgp[:,None])
|
||||
|
||||
# padded restraints
|
||||
pad = self.sg.shape[-1]//2+3
|
||||
cstd = torch.cat([self.padL+cstd[0],cstd,self.padR+cstd[-1]],dim=0)
|
||||
csto = torch.cat([csto[-pad:],csto,csto[:pad]],dim=0)
|
||||
cstt = torch.cat([cstt[-pad:],cstt,cstt[:pad]],dim=0)
|
||||
cstp = torch.cat([cstp[:pad].flip(0),cstp,cstp[-pad:].flip(0)],dim=0)
|
||||
|
||||
# smoothed restraints
|
||||
cstd,csto,cstt,cstp = [nn.functional.conv1d(cst.T.unsqueeze(1),self.sg)[:,0]
|
||||
for cst in [cstd,csto,cstt,cstp]]
|
||||
|
||||
# force distance restraints vanish at long distances
|
||||
cstd = cstd-cstd[:,-1][:,None]
|
||||
|
||||
# akima spline coefficients
|
||||
coefd = self.akima(cstd, self.params['DSTEP']).detach()
|
||||
coefo = self.akima(csto, self.params['ASTEP']).detach()
|
||||
coeft = self.akima(cstt, self.params['ASTEP']).detach()
|
||||
coefp = self.akima(cstp, self.params['ASTEP']).detach()
|
||||
|
||||
astep = self.params['ASTEP']
|
||||
|
||||
ko = torch.arange(i_s.shape[0],device=self.device).repeat(batch)
|
||||
kt = torch.arange(i_a.shape[0],device=self.device).repeat(batch)
|
||||
|
||||
# initial Ca placement using EDM+minimization
|
||||
xyz = perturb_init(xyz, batch) # (batch, L, 3)
|
||||
|
||||
# optimization variables: T - shift vectors, Q - rotation quaternions
|
||||
T = torch.zeros_like(xyz,device=self.device,requires_grad=True)
|
||||
Q = torch.randn([batch,L,4],device=self.device,requires_grad=True)
|
||||
bb0 = self.ncac[None,:,None,:].repeat(batch,1,L,1)
|
||||
|
||||
opt = torch.optim.Adam([T,Q], lr=lr)
|
||||
for step in range(nsteps):
|
||||
|
||||
|
||||
R = Q2R(Q/torch.norm(Q,dim=-1,keepdim=True))
|
||||
bb = torch.einsum("blij,bklj->bkli",R,bb0)+(xyz+T)[:,None]
|
||||
|
||||
# TODO: include Cb in the motif
|
||||
N,Ca,C = bb[:,0],bb[:,1],bb[:,2]
|
||||
Cb = get_cb(N,Ca,C)
|
||||
|
||||
o = get_dih(Ca[:,i_s],Cb[:,i_s],Cb[:,j_s],Ca[:,j_s]) + np.pi
|
||||
t = get_dih(N[:,i_a],Ca[:,i_a],Cb[:,i_a],Cb[:,j_a]) + np.pi
|
||||
p = get_ang(Ca[:,i_a],Cb[:,i_a],Cb[:,j_a])
|
||||
|
||||
dij = torch.norm(Cb[:,i_s]-Cb[:,j_s],dim=-1)
|
||||
b,k = torch.where(dij<20.0)
|
||||
dk = dij[b,k]
|
||||
|
||||
#coord = [coord/step-0.5 for coord,step in zip([dij,o,t,p],[dstep,astep,astep,astep])]
|
||||
#bins = [torch.ceil(c).long() for c in coord]
|
||||
#delta = [torch.frac(c) for c in coord]
|
||||
|
||||
kbin = torch.ceil((dk-0.25)/0.5).long()
|
||||
dx = (dk-0.25)%0.5
|
||||
c = coefd[k,kbin]
|
||||
lossd = c[:,0]+c[:,1]*dx+c[:,2]*dx**2+c[:,3]*dx**3
|
||||
|
||||
# omega
|
||||
obin = torch.ceil((o.view(-1)-astep/2)/astep).long()
|
||||
do = (o.view(-1)-astep/2)%astep
|
||||
co = coefo[ko,obin]
|
||||
losso = (co[:,0]+co[:,1]*do+co[:,2]*do**2+co[:,3]*do**3).view(batch,-1) #.mean(1)
|
||||
|
||||
# theta
|
||||
tbin = torch.ceil((t.view(-1)-astep/2)/astep).long()
|
||||
dt = (t.view(-1)-astep/2)%astep
|
||||
ct = coeft[kt,tbin]
|
||||
losst = (ct[:,0]+ct[:,1]*dt+ct[:,2]*dt**2+ct[:,3]*dt**3).view(batch,-1) #.mean(1)
|
||||
|
||||
# phi
|
||||
pbin = torch.ceil((p.view(-1)-astep/2)/astep).long()
|
||||
dp = (p.view(-1)-astep/2)%astep
|
||||
cp = coefp[kt,pbin]
|
||||
lossp = (cp[:,0]+cp[:,1]*dp+cp[:,2]*dp**2+cp[:,3]*dp**3).view(batch,-1) #.mean(1)
|
||||
|
||||
# restrain geometry of peptide bonds
|
||||
loss_nc = (torch.norm(C[:,:-1]-N[:,1:],dim=-1)-1.32868)**2
|
||||
loss_cacn = (get_ang(Ca[:,:-1], C[:,:-1], N[:,1:]) - 2.02807)**2
|
||||
loss_canc = (get_ang(Ca[:,1:], N[:,1:], C[:,:-1]) - 2.12407)**2
|
||||
|
||||
loss_geom = loss_nc.mean(1) + loss_cacn.mean(1) + loss_canc.mean(1)
|
||||
loss_ang = losst.mean(1) + losso.mean(1) + lossp.mean(1)
|
||||
|
||||
# coefficient for ramping up geometric restraints during minimization
|
||||
coef = (1.0+step)/nsteps
|
||||
|
||||
loss = lossd.mean() + self.params['WANG']*loss_ang.mean() + coef*self.params['WCST']*loss_geom.mean()
|
||||
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
lossd = torch.stack([lossd[b==i].mean() for i in range(batch)])
|
||||
loss = lossd + self.params['WANG']*loss_ang + self.params['WCST']*loss_geom
|
||||
minidx = torch.argmin(loss)
|
||||
|
||||
return bb[minidx].permute(1,0,2)
|
||||
|
253
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/util.py
Normal file
253
DGLPyTorch/DrugDiscovery/RoseTTAFold/network/util.py
Normal file
|
@ -0,0 +1,253 @@
|
|||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
num2aa=[
|
||||
'ALA','ARG','ASN','ASP','CYS',
|
||||
'GLN','GLU','GLY','HIS','ILE',
|
||||
'LEU','LYS','MET','PHE','PRO',
|
||||
'SER','THR','TRP','TYR','VAL',
|
||||
]
|
||||
|
||||
aa2num= {x:i for i,x in enumerate(num2aa)}
|
||||
|
||||
# minimal sc atom representation (Nx8)
|
||||
aa2short=[
|
||||
(" N "," CA "," C "," CB ", None, None, None, None), # ala
|
||||
(" N "," CA "," C "," CB "," CG "," CD "," NE "," CZ "), # arg
|
||||
(" N "," CA "," C "," CB "," CG "," OD1", None, None), # asn
|
||||
(" N "," CA "," C "," CB "," CG "," OD1", None, None), # asp
|
||||
(" N "," CA "," C "," CB "," SG ", None, None, None), # cys
|
||||
(" N "," CA "," C "," CB "," CG "," CD "," OE1", None), # gln
|
||||
(" N "," CA "," C "," CB "," CG "," CD "," OE1", None), # glu
|
||||
(" N "," CA "," C ", None, None, None, None, None), # gly
|
||||
(" N "," CA "," C "," CB "," CG "," ND1", None, None), # his
|
||||
(" N "," CA "," C "," CB "," CG1"," CD1", None, None), # ile
|
||||
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # leu
|
||||
(" N "," CA "," C "," CB "," CG "," CD "," CE "," NZ "), # lys
|
||||
(" N "," CA "," C "," CB "," CG "," SD "," CE ", None), # met
|
||||
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # phe
|
||||
(" N "," CA "," C "," CB "," CG "," CD ", None, None), # pro
|
||||
(" N "," CA "," C "," CB "," OG ", None, None, None), # ser
|
||||
(" N "," CA "," C "," CB "," OG1", None, None, None), # thr
|
||||
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # trp
|
||||
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # tyr
|
||||
(" N "," CA "," C "," CB "," CG1", None, None, None), # val
|
||||
]
|
||||
|
||||
# full sc atom representation (Nx14)
|
||||
aa2long=[
|
||||
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
|
||||
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
|
||||
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
||||
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
|
||||
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
|
||||
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
|
||||
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
|
||||
(" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
|
||||
]
|
||||
|
||||
# build the "alternate" sc mapping
|
||||
aa2longalt=[
|
||||
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH2"," NH1", None, None, None), # arg
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
|
||||
(" N "," CA "," C "," O "," CB "," CG "," OD2"," OD1", None, None, None, None, None, None), # asp
|
||||
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE2"," OE1", None, None, None, None, None), # glu
|
||||
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
||||
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
|
||||
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1", None, None, None, None, None, None), # leu
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
|
||||
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ ", None, None, None), # phe
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
|
||||
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
|
||||
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
|
||||
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ "," OH ", None, None), # tyr
|
||||
(" N "," CA "," C "," O "," CB "," CG2"," CG1", None, None, None, None, None, None, None), # val
|
||||
]
|
||||
|
||||
|
||||
# build "deterministic" atoms
|
||||
# see notebook (se3_experiments.ipynb for derivation)
|
||||
aa2frames=[
|
||||
[], # ala
|
||||
[ # arg
|
||||
[' NH1', ' CZ ', ' NE ', ' CD ', [-0.7218378782272339, 1.0856682062149048, -0.006118079647421837]],
|
||||
[' NH2', ' CZ ', ' NE ', ' CD ', [-0.6158039569854736, -1.1400136947631836, 0.006467342376708984]]],
|
||||
[ # asn
|
||||
[' ND2', ' CG ', ' CB ', ' OD1', [-0.6304131746292114, -1.1431225538253784, 0.02364802360534668]]],
|
||||
[ # asp
|
||||
[' OD2', ' CG ', ' CB ', ' OD1', [-0.5972501039505005, -1.0955055952072144, 0.04530305415391922]]],
|
||||
[], # cys
|
||||
[ # gln
|
||||
[' NE2', ' CD ', ' CG ', ' OE1', [-0.6558755040168762, -1.1324536800384521, 0.026521772146224976]]],
|
||||
[ # glu
|
||||
[' OE2', ' CD ', ' CG ', ' OE1', [-0.5578438639640808, -1.1161314249038696, -0.015464287251234055]]],
|
||||
[], # gly
|
||||
[ # his
|
||||
[' CD2', ' CG ', ' CB ', ' ND1', [-0.7502505779266357, -1.1680538654327393, 0.0005368441343307495]],
|
||||
[' CE1', ' CG ', ' CB ', ' ND1', [-2.0262467861175537, 0.539483368396759, -0.004495501518249512]],
|
||||
[' NE2', ' CG ', ' CB ', ' ND1', [-2.0761325359344482, -0.8199722766876221, -0.0018703639507293701]]],
|
||||
[ # ile
|
||||
[' CG2', ' CB ', ' CA ', ' CG1', [-0.6059935688972473, -0.8108057379722595, 1.1861376762390137]]],
|
||||
[ # leu
|
||||
[' CD2', ' CG ', ' CB ', ' CD1', [-0.5942193269729614, -0.7693282961845398, -1.1914138793945312]]],
|
||||
[], # lys
|
||||
[], # met
|
||||
[ # phe
|
||||
[' CD2', ' CG ', ' CB ', ' CD1', [-0.7164441347122192, -1.197853446006775, 0.06416648626327515]],
|
||||
[' CE1', ' CG ', ' CB ', ' CD1', [-2.0785865783691406, 1.2366485595703125, 0.08100450038909912]],
|
||||
[' CE2', ' CG ', ' CB ', ' CD1', [-2.107091188430786, -1.178497076034546, 0.13524535298347473]],
|
||||
[' CZ ', ' CG ', ' CB ', ' CD1', [-2.786630630493164, 0.03873880207538605, 0.14633776247501373]]],
|
||||
[], # pro
|
||||
[], # ser
|
||||
[ # thr
|
||||
[' CG2', ' CB ', ' CA ', ' OG1', [-0.6842088103294373, -0.6709619164466858, 1.2105456590652466]]],
|
||||
[ # trp
|
||||
[' CD2', ' CG ', ' CB ', ' CD1', [-0.8550368547439575, -1.0790592432022095, 0.09017711877822876]],
|
||||
[' NE1', ' CG ', ' CB ', ' CD1', [-2.1863200664520264, 0.8064242601394653, 0.08350661396980286]],
|
||||
[' CE2', ' CG ', ' CB ', ' CD1', [-2.1801204681396484, -0.5795643329620361, 0.14015203714370728]],
|
||||
[' CE3', ' CG ', ' CB ', ' CD1', [-0.605582594871521, -2.4733362197875977, 0.16200461983680725]],
|
||||
[' CE2', ' CG ', ' CB ', ' CD1', [-2.1801204681396484, -0.5795643329620361, 0.14015203714370728]],
|
||||
[' CZ2', ' CG ', ' CB ', ' CD1', [-3.2672977447509766, -1.473116159439087, 0.250858873128891]],
|
||||
[' CZ3', ' CG ', ' CB ', ' CD1', [-1.6969941854476929, -3.3360071182250977, 0.264143705368042]],
|
||||
[' CH2', ' CG ', ' CB ', ' CD1', [-3.009331703186035, -2.8451972007751465, 0.3059283494949341]]],
|
||||
[ # tyr
|
||||
[' CD2', ' CG ', ' CB ', ' CD1', [-0.69439297914505, -1.2123756408691406, -0.009198814630508423]],
|
||||
[' CE1', ' CG ', ' CB ', ' CD1', [-2.104464054107666, 1.1910505294799805, -0.014679580926895142]],
|
||||
[' CE2', ' CG ', ' CB ', ' CD1', [-2.0857787132263184, -1.2231677770614624, -0.024517983198165894]],
|
||||
[' CZ ', ' CG ', ' CB ', ' CD1', [-2.7897322177886963, -0.021470561623573303, -0.026979409158229828]],
|
||||
[' OH ', ' CG ', ' CB ', ' CD1', [-4.1559271812438965, -0.029129385948181152, -0.044720835983753204]]],
|
||||
[ # val
|
||||
[' CG2', ' CB ', ' CA ', ' CG1', [-0.6258467435836792, -0.7654698491096497, -1.1894742250442505]]],
|
||||
]
|
||||
|
||||
# O from frame (C,N-1,CA)
|
||||
bb2oframe=[-0.5992066264152527, -1.0820008516311646, 0.0001476481556892395]
|
||||
|
||||
# build the mapping from indices in reduced representation to
|
||||
# indices in the full representation
|
||||
# N x 14 x 6 = <base-idx | parent-idx | gparent-idx | x | y | z >
|
||||
# base-idx < 0 ==> no atom
|
||||
# xyz = 0 ==> no mapping
|
||||
short2long = np.zeros((20,14,6))
|
||||
for i in range(20):
|
||||
i_s, i_l = aa2short[i],aa2long[i]
|
||||
for j,a in enumerate(i_l):
|
||||
# case 1: if no atom defined, blank
|
||||
if (a is None):
|
||||
short2long[i,j,0] = -1
|
||||
# case 2: atom is a base atom
|
||||
elif (a in i_s):
|
||||
short2long[i,j,0] = i_s.index(a)
|
||||
if (short2long[i,j,0] == 0):
|
||||
short2long[i,j,1] = 1
|
||||
short2long[i,j,2] = 2
|
||||
else:
|
||||
short2long[i,j,1] = 0
|
||||
if (short2long[i,j,0] == 1):
|
||||
short2long[i,j,2] = 2
|
||||
else:
|
||||
short2long[i,j,2] = 1
|
||||
# case 3: atom is ' O '
|
||||
elif (a == " O "):
|
||||
short2long[i,j,0] = 2
|
||||
short2long[i,j,1] = 0 #Nprev (will pre-roll N as nothing else needs it)
|
||||
short2long[i,j,2] = 1
|
||||
short2long[i,j,3:] = np.array(bb2oframe)
|
||||
# case 4: build this atom
|
||||
else:
|
||||
i_f = aa2frames[i]
|
||||
names = [f[0] for f in i_f]
|
||||
idx = names.index(a)
|
||||
short2long[i,j,0] = i_s.index(i_f[idx][1])
|
||||
short2long[i,j,1] = i_s.index(i_f[idx][2])
|
||||
short2long[i,j,2] = i_s.index(i_f[idx][3])
|
||||
short2long[i,j,3:] = np.array(i_f[idx][4])
|
||||
|
||||
# build the mapping from atoms in the full rep (Nx14) to the "alternate" rep
|
||||
long2alt = np.zeros((20,14))
|
||||
for i in range(20):
|
||||
i_l, i_lalt = aa2long[i], aa2longalt[i]
|
||||
for j,a in enumerate(i_l):
|
||||
if (a is None):
|
||||
long2alt[i,j] = j
|
||||
else:
|
||||
long2alt[i,j] = i_lalt.index(a)
|
||||
|
||||
def atoms_from_frames(base,parent,gparent,points):
|
||||
xs = parent-base
|
||||
|
||||
# handle parent==base
|
||||
mask = (torch.sum(torch.square(xs),dim=-1)==0)
|
||||
xs[mask,0] = 1.0
|
||||
xs = xs / torch.norm(xs, dim=-1)[:,None]
|
||||
|
||||
ys = gparent-base
|
||||
ys = ys - torch.sum(xs*ys,dim=-1)[:,None]*xs
|
||||
|
||||
# handle gparent==base
|
||||
mask = (torch.sum(torch.square(ys),dim=-1)==0)
|
||||
ys[mask,1] = 1.0
|
||||
|
||||
ys = ys / torch.norm(ys, dim=-1)[:,None]
|
||||
zs = torch.cross(xs,ys)
|
||||
q = torch.stack((xs,ys,zs),dim=2)
|
||||
|
||||
retval = base + torch.einsum('nij,nj->ni',q,points)
|
||||
|
||||
return retval
|
||||
#def atoms_from_frames(base,parent,gparent,points):
|
||||
# xs = parent-base
|
||||
# # handle parent=base
|
||||
# mask = (torch.sum(torch.square(xs), dim=-1) == 0)
|
||||
# xs[mask,0] = 1.0
|
||||
# xs = xs / torch.norm(xs, dim=-1)[:,None]
|
||||
#
|
||||
# ys = gparent-base
|
||||
# # handle gparent=base
|
||||
# mask = (torch.sum(torch.square(ys),dim=-1)==0)
|
||||
# ys[mask,1] = 1.0
|
||||
#
|
||||
# ys = ys - torch.sum(xs*ys,dim=-1)[:,None]*xs
|
||||
# ys = ys / torch.norm(ys, dim=-1)[:,None]
|
||||
# zs = torch.cross(xs,ys)
|
||||
# q = torch.stack((xs,ys,zs),dim=2)
|
||||
# #return base + q@points
|
||||
# return base + torch.einsum('nij,nj->ni',q,points)
|
||||
|
||||
# writepdb
|
||||
def writepdb(filename, atoms, bfacts, seq):
|
||||
f = open(filename,"w")
|
||||
|
||||
ctr = 1
|
||||
scpu = seq.cpu()
|
||||
atomscpu = atoms.cpu()
|
||||
Bfacts = torch.clamp( bfacts.cpu(), 0, 1)
|
||||
for i,s in enumerate(scpu):
|
||||
atms = aa2long[s]
|
||||
for j,atm_j in enumerate(atms):
|
||||
if (atm_j is not None):
|
||||
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
|
||||
"ATOM", ctr, atm_j, num2aa[s],
|
||||
"A", i+1, atomscpu[i,j,0], atomscpu[i,j,1], atomscpu[i,j,2],
|
||||
1.0, Bfacts[i] ) )
|
||||
ctr += 1
|
|
@ -0,0 +1,64 @@
|
|||
import warnings
|
||||
|
||||
import dgl
|
||||
import torch
|
||||
|
||||
|
||||
def to_np(x):
|
||||
return x.cpu().detach().numpy()
|
||||
|
||||
|
||||
class PickleGraph:
|
||||
"""Lightweight graph object for easy pickling. Does not support batched graphs."""
|
||||
|
||||
def __init__(self, G=None, desired_keys=None):
|
||||
self.ndata = dict()
|
||||
self.edata = dict()
|
||||
|
||||
if G is None:
|
||||
self.src = []
|
||||
self.dst = []
|
||||
else:
|
||||
if G.batch_size > 1:
|
||||
warnings.warn("Copying a batched graph to a PickleGraph is not supported. "
|
||||
"All node and edge data will be copied, but batching information will be lost.")
|
||||
|
||||
self.src, self.dst = (to_np(idx) for idx in G.all_edges())
|
||||
|
||||
for k in G.ndata:
|
||||
if desired_keys is None or k in desired_keys:
|
||||
self.ndata[k] = to_np(G.ndata[k])
|
||||
|
||||
for k in G.edata:
|
||||
if desired_keys is None or k in desired_keys:
|
||||
self.edata[k] = to_np(G.edata[k])
|
||||
|
||||
def all_edges(self):
|
||||
return self.src, self.dst
|
||||
|
||||
|
||||
def copy_dgl_graph(G):
|
||||
if G.batch_size == 1:
|
||||
src, dst = G.all_edges()
|
||||
G2 = dgl.DGLGraph((src, dst))
|
||||
for edge_key in list(G.edata.keys()):
|
||||
G2.edata[edge_key] = torch.clone(G.edata[edge_key])
|
||||
for node_key in list(G.ndata.keys()):
|
||||
G2.ndata[node_key] = torch.clone(G.ndata[node_key])
|
||||
return G2
|
||||
else:
|
||||
list_of_graphs = dgl.unbatch(G)
|
||||
list_of_copies = []
|
||||
|
||||
for batch_G in list_of_graphs:
|
||||
list_of_copies.append(copy_dgl_graph(batch_G))
|
||||
|
||||
return dgl.batch(list_of_copies)
|
||||
|
||||
|
||||
def update_relative_positions(G, *, relative_position_key='d', absolute_position_key='x'):
|
||||
"""For each directed edge in the graph, calculate the relative position of the destination node with respect
|
||||
to the source node. Write the relative positions to the graph as edge data."""
|
||||
src, dst = G.all_edges()
|
||||
absolute_positions = G.ndata[absolute_position_key]
|
||||
G.edata[relative_position_key] = absolute_positions[dst] - absolute_positions[src]
|
|
@ -0,0 +1,123 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from utils.utils_data import to_np
|
||||
|
||||
|
||||
_global_log = {}
|
||||
|
||||
|
||||
def try_mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
# @profile
|
||||
def make_logdir(checkpoint_dir, run_name=None):
|
||||
if run_name is None:
|
||||
now = datetime.datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
|
||||
else:
|
||||
assert type(run_name) == str
|
||||
now = run_name
|
||||
|
||||
log_dir = os.path.join(checkpoint_dir, now)
|
||||
try_mkdir(log_dir)
|
||||
return log_dir
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
"""
|
||||
count number of trainable parameters in module
|
||||
:param model: nn.Module instance
|
||||
:return: integer
|
||||
"""
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
n_params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return n_params
|
||||
|
||||
|
||||
def write_info_file(model, FLAGS, UNPARSED_ARGV, wandb_log_dir=None):
|
||||
time_str = time.strftime("%m%d_%H%M%S")
|
||||
filename_log = "info_" + time_str + ".txt"
|
||||
filename_git_diff = "git_diff_" + time_str + ".txt"
|
||||
|
||||
checkpoint_name = 'model'
|
||||
|
||||
if wandb_log_dir:
|
||||
log_dir = wandb_log_dir
|
||||
os.mkdir(os.path.join(log_dir, 'checkpoints'))
|
||||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
|
||||
elif FLAGS.restore:
|
||||
# set restore path
|
||||
assert FLAGS.run_name is not None
|
||||
log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.run_name)
|
||||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
|
||||
else:
|
||||
# makes logdir with time stamp
|
||||
log_dir = make_logdir(FLAGS.checkpoint_dir, FLAGS.run_name)
|
||||
os.mkdir(os.path.join(log_dir, 'checkpoints'))
|
||||
os.mkdir(os.path.join(log_dir, 'point_clouds'))
|
||||
# os.mkdir(os.path.join(log_dir, 'train_log'))
|
||||
# os.mkdir(os.path.join(log_dir, 'test_log'))
|
||||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
|
||||
|
||||
# writing arguments and git hash to info file
|
||||
file = open(os.path.join(log_dir, filename_log), "w")
|
||||
label = subprocess.check_output(["git", "describe", "--always"]).strip()
|
||||
file.write('latest git commit on this branch: ' + str(label) + '\n')
|
||||
file.write('\nFLAGS: \n')
|
||||
for key in sorted(vars(FLAGS)):
|
||||
file.write(key + ': ' + str(vars(FLAGS)[key]) + '\n')
|
||||
|
||||
# count number of parameters
|
||||
if hasattr(model, 'parameters'):
|
||||
file.write('\nNumber of Model Parameters: ' + str(count_parameters(model)) + '\n')
|
||||
if hasattr(model, 'enc'):
|
||||
file.write('\nNumber of Encoder Parameters: ' + str(
|
||||
count_parameters(model.enc)) + '\n')
|
||||
if hasattr(model, 'dec'):
|
||||
file.write('\nNumber of Decoder Parameters: ' + str(
|
||||
count_parameters(model.dec)) + '\n')
|
||||
|
||||
file.write('\nUNPARSED_ARGV:\n' + str(UNPARSED_ARGV))
|
||||
file.write('\n\nBASH COMMAND: \n')
|
||||
bash_command = 'python'
|
||||
for argument in sys.argv:
|
||||
bash_command += (' ' + argument)
|
||||
file.write(bash_command)
|
||||
file.close()
|
||||
|
||||
# write 'git diff' output into extra file
|
||||
subprocess.call(["git diff > " + os.path.join(log_dir, filename_git_diff)], shell=True)
|
||||
|
||||
return log_dir, checkpoint_path
|
||||
|
||||
|
||||
def log_gradient_norm(tensor, variable_name):
|
||||
if variable_name not in _global_log:
|
||||
_global_log[variable_name] = []
|
||||
|
||||
def log_gradient_norm_inner(gradient):
|
||||
gradient_norm = torch.norm(gradient, dim=-1)
|
||||
_global_log[variable_name].append(to_np(gradient_norm))
|
||||
|
||||
tensor.register_hook(log_gradient_norm_inner)
|
||||
|
||||
|
||||
def get_average(variable_name):
|
||||
if variable_name not in _global_log:
|
||||
return float('nan')
|
||||
elif _global_log[variable_name]:
|
||||
overall_tensor = np.concatenate(_global_log[variable_name])
|
||||
return np.mean(overall_tensor)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def clear_data(variable_name):
|
||||
_global_log[variable_name] = []
|
|
@ -0,0 +1,5 @@
|
|||
try:
|
||||
profile
|
||||
except NameError:
|
||||
def profile(func):
|
||||
return func
|
File diff suppressed because one or more lines are too long
42
DGLPyTorch/DrugDiscovery/RoseTTAFold/pipeline_utils.py
Normal file
42
DGLPyTorch/DrugDiscovery/RoseTTAFold/pipeline_utils.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import os
|
||||
import py3Dmol
|
||||
|
||||
|
||||
def execute_pipeline(sequence):
|
||||
if os.path.isfile(sequence):
|
||||
with open(sequence, "r") as f:
|
||||
title = f.readlines()[0][2:]
|
||||
print(f"Running inference on {title}")
|
||||
os.system(f"bash run_inference_pipeline.sh {sequence}")
|
||||
else:
|
||||
try:
|
||||
with open("temp_input.fa", "w") as f:
|
||||
f.write(f"> {sequence[:8]}...\n")
|
||||
f.write(sequence.upper())
|
||||
print(f"Running inference on {sequence[:8]}...")
|
||||
os.system(f"bash run_inference_pipeline.sh temp_input.fa")
|
||||
except:
|
||||
print("Unable to run the pipeline.")
|
||||
raise
|
||||
|
||||
|
||||
def display_pdb(path_to_pdb):
|
||||
with open(path_to_pdb) as ifile:
|
||||
protein = "".join([x for x in ifile])
|
||||
view = py3Dmol.view(width=400, height=300)
|
||||
view.addModelsAsFrames(protein)
|
||||
view.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
|
||||
view.zoomTo()
|
||||
view.show()
|
||||
|
||||
|
||||
def cleanup():
|
||||
os.system("rm t000*")
|
4
DGLPyTorch/DrugDiscovery/RoseTTAFold/requirements.txt
Normal file
4
DGLPyTorch/DrugDiscovery/RoseTTAFold/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
pandas==1.1.4
|
||||
scikit-learn==0.24
|
||||
packaging==21.0
|
||||
py3Dmol==1.7.0
|
78
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_e2e_ver.sh
Normal file
78
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_e2e_ver.sh
Normal file
|
@ -0,0 +1,78 @@
|
|||
#!/bin/bash
|
||||
|
||||
# make the script stop when error (non-true exit code) is occured
|
||||
set -e
|
||||
|
||||
############################################################
|
||||
# >>> conda initialize >>>
|
||||
# !! Contents within this block are managed by 'conda init' !!
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||
eval "$__conda_setup"
|
||||
unset __conda_setup
|
||||
# <<< conda initialize <<<
|
||||
############################################################
|
||||
|
||||
SCRIPT=`realpath -s $0`
|
||||
export PIPEDIR=`dirname $SCRIPT`
|
||||
|
||||
CPU="8" # number of CPUs to use
|
||||
MEM="64" # max memory (in GB)
|
||||
|
||||
# Inputs:
|
||||
IN="$1" # input.fasta
|
||||
WDIR=`realpath -s $2` # working folder
|
||||
|
||||
|
||||
LEN=`tail -n1 $IN | wc -m`
|
||||
|
||||
mkdir -p $WDIR/log
|
||||
|
||||
conda activate RoseTTAFold
|
||||
############################################################
|
||||
# 1. generate MSAs
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.msa0.a3m ]
|
||||
then
|
||||
echo "Running HHblits"
|
||||
$PIPEDIR/input_prep/make_msa.sh $IN $WDIR $CPU $MEM > $WDIR/log/make_msa.stdout 2> $WDIR/log/make_msa.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 2. predict secondary structure for HHsearch run
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.ss2 ]
|
||||
then
|
||||
echo "Running PSIPRED"
|
||||
$PIPEDIR/input_prep/make_ss.sh $WDIR/t000_.msa0.a3m $WDIR/t000_.ss2 > $WDIR/log/make_ss.stdout 2> $WDIR/log/make_ss.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 3. search for templates
|
||||
############################################################
|
||||
DB="$PIPEDIR/pdb100_2021Mar03/pdb100_2021Mar03"
|
||||
if [ ! -s $WDIR/t000_.hhr ]
|
||||
then
|
||||
echo "Running hhsearch"
|
||||
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB"
|
||||
cat $WDIR/t000_.ss2 $WDIR/t000_.msa0.a3m > $WDIR/t000_.msa0.ss2.a3m
|
||||
$HH -i $WDIR/t000_.msa0.ss2.a3m -o $WDIR/t000_.hhr -atab $WDIR/t000_.atab -v 0 > $WDIR/log/hhsearch.stdout 2> $WDIR/log/hhsearch.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 4. end-to-end prediction
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.3track.npz ]
|
||||
then
|
||||
echo "Running end-to-end prediction"
|
||||
python $PIPEDIR/network/predict_e2e.py \
|
||||
-m $PIPEDIR/weights \
|
||||
-i $WDIR/t000_.msa0.a3m \
|
||||
-o $WDIR/t000_.e2e \
|
||||
--hhr $WDIR/t000_.hhr \
|
||||
--atab $WDIR/t000_.atab \
|
||||
--db $DB 1> $WDIR/log/network.stdout 2> $WDIR/log/network.stderr
|
||||
fi
|
||||
echo "Done"
|
|
@ -0,0 +1,17 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
from pipeline_utils import execute_pipeline
|
||||
|
||||
PARSER = argparse.ArgumentParser()
|
||||
PARSER.add_argument('sequence', type=str, help='A sequence or a path to a FASTA file')
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = PARSER.parse_args()
|
||||
execute_pipeline(args.sequence)
|
76
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_inference_pipeline.sh
Executable file
76
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_inference_pipeline.sh
Executable file
|
@ -0,0 +1,76 @@
|
|||
#!/bin/bash
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# make the script stop when error (non-true exit code) is occurred
|
||||
set -e
|
||||
CPU="32" # number of CPUs to use
|
||||
MEM="64" # max memory (in GB)
|
||||
|
||||
# Inputs:
|
||||
IN="$1" # input.fasta
|
||||
#WDIR=`realpath -s $2` # working folder
|
||||
DATABASES_DIR="${2:-/databases}"
|
||||
WEIGHTS_DIR="${3:-/weights}"
|
||||
WDIR="${4:-.}"
|
||||
|
||||
|
||||
|
||||
LEN=`tail -n1 $IN | wc -m`
|
||||
|
||||
mkdir -p $WDIR/logs
|
||||
|
||||
###########################################################
|
||||
# 1. generate MSAs
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.msa0.a3m ]
|
||||
then
|
||||
echo "Running HHblits - looking for the MSAs"
|
||||
/workspace/rf/input_prep/make_msa.sh $IN $WDIR $CPU $MEM $DATABASES_DIR > $WDIR/logs/make_msa.stdout 2> $WDIR/logs/make_msa.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 2. predict secondary structure for HHsearch run
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.ss2 ]
|
||||
then
|
||||
echo "Running PSIPRED - looking for the Secondary Structures"
|
||||
/workspace/rf/input_prep/make_ss.sh $WDIR/t000_.msa0.a3m $WDIR/t000_.ss2 > $WDIR/logs/make_ss.stdout 2> $WDIR/logs/make_ss.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 3. search for templates
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.hhr ]
|
||||
then
|
||||
echo "Running hhsearch - looking for the Templates"
|
||||
/workspace/rf/input_prep/prepare_templates.sh $WDIR $CPU $MEM $DATABASES_DIR > $WDIR/logs/prepare_templates.stdout 2> $WDIR/logs/prepare_templates.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 4. end-to-end prediction
|
||||
############################################################
|
||||
DB="$DATABASES_DIR/pdb100_2021Mar03/pdb100_2021Mar03"
|
||||
if [ ! -s $WDIR/t000_.3track.npz ]
|
||||
then
|
||||
echo "Running end-to-end prediction"
|
||||
python /workspace/rf/network/predict_e2e.py \
|
||||
-m $WEIGHTS_DIR \
|
||||
-i $WDIR/t000_.msa0.a3m \
|
||||
-o /results/output.e2e \
|
||||
--hhr $WDIR/t000_.hhr \
|
||||
--atab $WDIR/t000_.atab \
|
||||
--db $DB
|
||||
fi
|
||||
echo "Done."
|
||||
echo "Output saved as /results/output.e2e.pdb"
|
||||
|
||||
# 1> $WDIR/log/network.stdout 2> $WDIR/log/network.stderr
|
123
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_pyrosetta_ver.sh
Executable file
123
DGLPyTorch/DrugDiscovery/RoseTTAFold/run_pyrosetta_ver.sh
Executable file
|
@ -0,0 +1,123 @@
|
|||
#!/bin/bash
|
||||
|
||||
# make the script stop when error (non-true exit code) is occured
|
||||
set -e
|
||||
|
||||
############################################################
|
||||
# >>> conda initialize >>>
|
||||
# !! Contents within this block are managed by 'conda init' !!
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2> /dev/null)"
|
||||
eval "$__conda_setup"
|
||||
unset __conda_setup
|
||||
# <<< conda initialize <<<
|
||||
############################################################
|
||||
|
||||
SCRIPT=`realpath -s $0`
|
||||
export PIPEDIR=`dirname $SCRIPT`
|
||||
|
||||
CPU="8" # number of CPUs to use
|
||||
MEM="64" # max memory (in GB)
|
||||
|
||||
# Inputs:
|
||||
IN="$1" # input.fasta
|
||||
WDIR=`realpath -s $2` # working folder
|
||||
|
||||
|
||||
LEN=`tail -n1 $IN | wc -m`
|
||||
|
||||
mkdir -p $WDIR/log
|
||||
|
||||
conda activate RoseTTAFold
|
||||
############################################################
|
||||
# 1. generate MSAs
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.msa0.a3m ]
|
||||
then
|
||||
echo "Running HHblits"
|
||||
$PIPEDIR/input_prep/make_msa.sh $IN $WDIR $CPU $MEM > $WDIR/log/make_msa.stdout 2> $WDIR/log/make_msa.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 2. predict secondary structure for HHsearch run
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.ss2 ]
|
||||
then
|
||||
echo "Running PSIPRED"
|
||||
$PIPEDIR/input_prep/make_ss.sh $WDIR/t000_.msa0.a3m $WDIR/t000_.ss2 > $WDIR/log/make_ss.stdout 2> $WDIR/log/make_ss.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 3. search for templates
|
||||
############################################################
|
||||
DB="$PIPEDIR/pdb100_2021Mar03/pdb100_2021Mar03"
|
||||
if [ ! -s $WDIR/t000_.hhr ]
|
||||
then
|
||||
echo "Running hhsearch"
|
||||
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB"
|
||||
cat $WDIR/t000_.ss2 $WDIR/t000_.msa0.a3m > $WDIR/t000_.msa0.ss2.a3m
|
||||
$HH -i $WDIR/t000_.msa0.ss2.a3m -o $WDIR/t000_.hhr -atab $WDIR/t000_.atab -v 0 > $WDIR/log/hhsearch.stdout 2> $WDIR/log/hhsearch.stderr
|
||||
fi
|
||||
|
||||
|
||||
############################################################
|
||||
# 4. predict distances and orientations
|
||||
############################################################
|
||||
if [ ! -s $WDIR/t000_.3track.npz ]
|
||||
then
|
||||
echo "Predicting distance and orientations"
|
||||
python $PIPEDIR/network/predict_pyRosetta.py \
|
||||
-m $PIPEDIR/weights \
|
||||
-i $WDIR/t000_.msa0.a3m \
|
||||
-o $WDIR/t000_.3track \
|
||||
--hhr $WDIR/t000_.hhr \
|
||||
--atab $WDIR/t000_.atab \
|
||||
--db $DB 1> $WDIR/log/network.stdout 2> $WDIR/log/network.stderr
|
||||
fi
|
||||
|
||||
############################################################
|
||||
# 5. perform modeling
|
||||
############################################################
|
||||
mkdir -p $WDIR/pdb-3track
|
||||
|
||||
conda deactivate
|
||||
conda activate folding
|
||||
|
||||
for m in 0 1 2
|
||||
do
|
||||
for p in 0.05 0.15 0.25 0.35 0.45
|
||||
do
|
||||
for ((i=0;i<1;i++))
|
||||
do
|
||||
if [ ! -f $WDIR/pdb-3track/model${i}_${m}_${p}.pdb ]; then
|
||||
echo "python -u $PIPEDIR/folding/RosettaTR.py --roll -r 3 -pd $p -m $m -sg 7,3 $WDIR/t000_.3track.npz $IN $WDIR/pdb-3track/model${i}_${m}_${p}.pdb"
|
||||
fi
|
||||
done
|
||||
done
|
||||
done > $WDIR/parallel.fold.list
|
||||
|
||||
N=`cat $WDIR/parallel.fold.list | wc -l`
|
||||
if [ "$N" -gt "0" ]; then
|
||||
echo "Running parallel RosettaTR.py"
|
||||
parallel -j $CPU < $WDIR/parallel.fold.list > $WDIR/log/folding.stdout 2> $WDIR/log/folding.stderr
|
||||
fi
|
||||
|
||||
############################################################
|
||||
# 6. Pick final models
|
||||
############################################################
|
||||
count=$(find $WDIR/pdb-3track -maxdepth 1 -name '*.npz' | grep -v 'features' | wc -l)
|
||||
if [ "$count" -lt "15" ]; then
|
||||
# run DeepAccNet-msa
|
||||
echo "Running DeepAccNet-msa"
|
||||
python $PIPEDIR/DAN-msa/ErrorPredictorMSA.py --roll -p $CPU $WDIR/t000_.3track.npz $WDIR/pdb-3track $WDIR/pdb-3track 1> $WDIR/log/DAN_msa.stdout 2> $WDIR/log/DAN_msa.stderr
|
||||
fi
|
||||
|
||||
if [ ! -s $WDIR/model/model_5.crderr.pdb ]
|
||||
then
|
||||
echo "Picking final models"
|
||||
python -u -W ignore $PIPEDIR/DAN-msa/pick_final_models.div.py \
|
||||
$WDIR/pdb-3track $WDIR/model $CPU > $WDIR/log/pick.stdout 2> $WDIR/log/pick.stderr
|
||||
echo "Final models saved in: $2/model"
|
||||
fi
|
||||
echo "Done"
|
|
@ -0,0 +1,33 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
WEIGHTS_DIR="${1:-.}"
|
||||
DATABASES_DIR="${2:-./databases}"
|
||||
|
||||
mkdir -p $DATABASES_DIR
|
||||
|
||||
echo "Downloading pre-trained model weights [1G]"
|
||||
wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
|
||||
tar xfz weights.tar.gz -C $WEIGHTS_DIR
|
||||
|
||||
|
||||
# uniref30 [46G]
|
||||
echo "Downloading UniRef30_2020_06 [46G]"
|
||||
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz -P $DATABASES_DIR
|
||||
mkdir -p $DATABASES_DIR/UniRef30_2020_06
|
||||
tar xfz $DATABASES_DIR/UniRef30_2020_06_hhsuite.tar.gz -C $DATABASES_DIR/UniRef30_2020_06
|
||||
|
||||
# BFD [272G]
|
||||
#wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
|
||||
#mkdir -p bfd
|
||||
#tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
|
||||
|
||||
# structure templates (including *_a3m.ffdata, *_a3m.ffindex) [over 100G]
|
||||
echo "Downloading pdb100_2021Mar03 [over 100G]"
|
||||
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz -P $DATABASES_DIR
|
||||
tar xfz $DATABASES_DIR/pdb100_2021Mar03.tar.gz -C $DATABASES_DIR/
|
11
DGLPyTorch/DrugDiscovery/RoseTTAFold/start_jupyter.sh
Normal file
11
DGLPyTorch/DrugDiscovery/RoseTTAFold/start_jupyter.sh
Normal file
|
@ -0,0 +1,11 @@
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
PORT="${1:-6006}"
|
||||
|
||||
jupyter lab --ip=0.0.0.0 --allow-root --no-browser --NotebookApp.token='' --notebook-dir=/workspace/rf --NotebookApp.allow_origin='*' --port $PORT
|
Loading…
Reference in a new issue