[RoseTTAFold] Initial release

This commit is contained in:
michalm 2021-10-10 19:35:57 +02:00 committed by Andrei Shumak
parent 26d8955cc5
commit 0db746b4ab
83 changed files with 8446 additions and 0 deletions

View file

@ -0,0 +1,73 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Base image shared by both stages; override with --build-arg FROM_IMAGE_NAME=...
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.09-py3

# ---- Stage 1: build DGL v0.7.0 from source with CUDA and FP16 enabled ----
FROM ${FROM_IMAGE_NAME} AS dgl_builder

# Build-time only: stops apt-get from prompting. Declared as ARG so that it is
# visible to the RUN steps of this stage without being baked into an image env.
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
 && apt-get install -y git build-essential python3-dev make cmake \
 && rm -rf /var/lib/apt/lists/*

WORKDIR /dgl
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
# Rewrite the "35 50 60 70" list in DGL's CUDA.cmake to "60 70 80"
# (presumably the CUDA compute capabilities to build for -- confirm against DGL's CMake files).
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake

# Absolute path instead of the relative `build`, so the location does not
# depend on the preceding WORKDIR.
WORKDIR /dgl/build
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
RUN make -j8
# ---- Stage 2: runtime image with the PyG extensions, DGL, HH-suite, PSIPRED tooling and the repo ----
FROM ${FROM_IMAGE_NAME}

# Make the working directory explicit so the relative paths below do not
# depend on whatever default directory the base image happens to set.
WORKDIR /workspace

# VERY IMPORTANT, DO NOT REMOVE:
# (presumably makes the torch-* extensions below build their CUDA code even
# though no GPU is visible during `docker build` -- confirm before changing)
ENV FORCE_CUDA=1
RUN pip install -v torch-geometric
RUN pip install -v torch-scatter
RUN pip install -v torch-sparse
RUN pip install -v torch-cluster
RUN pip install -v torch-spline-conv

# copy built DGL and install it
COPY --from=dgl_builder /dgl ./dgl
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
ENV DGLBACKEND=pytorch
#RUN pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html

# HH-Suite
RUN git clone https://github.com/soedinglab/hh-suite.git && \
    mkdir -p hh-suite/build
WORKDIR /workspace/hh-suite/build
RUN cmake .. && \
    make && \
    make install

# PSIPRED
WORKDIR /workspace
RUN wget http://wwwuser.gwdg.de/~compbiol/data/csblast/releases/csblast-2.2.3_linux64.tar.gz -O csblast-2.2.3.tar.gz && \
    mkdir -p csblast-2.2.3 && \
    tar xf csblast-2.2.3.tar.gz -C csblast-2.2.3 --strip-components=1 && \
    rm csblast-2.2.3.tar.gz
RUN wget https://ftp.ncbi.nlm.nih.gov/blast/executables/legacy.NOTSUPPORTED/2.2.26/blast-2.2.26-x64-linux.tar.gz && \
    tar xf blast-2.2.26-x64-linux.tar.gz && \
    rm blast-2.2.26-x64-linux.tar.gz
RUN wget http://bioinfadmin.cs.ucl.ac.uk/downloads/psipred/psipred.4.02.tar.gz && \
    tar xf psipred.4.02.tar.gz && \
    rm psipred.4.02.tar.gz

# Copy the repository into the image. COPY rather than ADD: this is a plain
# local copy and needs neither archive extraction nor URL fetching.
COPY . /workspace/rf
WORKDIR /workspace/rf
# Download the lDDT binaries; the archive is deleted in the same layer, matching
# the other download steps above.
RUN wget https://openstructure.org/static/lddt-linux.zip -O lddt.zip && \
    unzip -d lddt -j lddt.zip && \
    rm lddt.zip
RUN pip install --upgrade pip
RUN pip install -r requirements.txt

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 RosettaCommons
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,94 @@
# *RoseTTAFold*
This package contains deep learning models and related scripts to run RoseTTAFold.
This repository is the official implementation of RoseTTAFold: Accurate prediction of protein structures and interactions using a 3-track network.
## Installation
1. Clone the package
```
git clone https://github.com/RosettaCommons/RoseTTAFold.git
cd RoseTTAFold
```
2. Create conda environments using the `RoseTTAFold-linux.yml` and `folding-linux.yml` files. The latter is required only to run the pyrosetta version (run_pyrosetta_ver.sh).
```
# create conda environment for RoseTTAFold
# If your NVIDIA driver is compatible with cuda11
conda env create -f RoseTTAFold-linux.yml
# If not (but compatible with cuda10)
conda env create -f RoseTTAFold-linux-cu101.yml
# create conda environment for pyRosetta folding & running DeepAccNet
conda env create -f folding-linux.yml
```
3. Download network weights (under Rosetta-DL Software license -- please see below)
While the code is licensed under the MIT License, the trained weights and data for RoseTTAFold are made available for non-commercial use only under the terms of the Rosetta-DL Software license. You can find details at https://files.ipd.uw.edu/pub/RoseTTAFold/Rosetta-DL_LICENSE.txt
```
wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
tar xfz weights.tar.gz
```
4. Download and install third-party software.
```
./install_dependencies.sh
```
5. Download sequence and structure databases
```
# uniref30 [46G]
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz
mkdir -p UniRef30_2020_06
tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06
# BFD [272G]
wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
mkdir -p bfd
tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
# structure templates (including *_a3m.ffdata, *_a3m.ffindex) [over 100G]
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz
tar xfz pdb100_2021Mar03.tar.gz
# for CASP14 benchmarks, we used this one: https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2020Mar11.tar.gz
```
6. Obtain a [PyRosetta licence](https://els2.comotion.uw.edu/product/pyrosetta) and install the package in the newly created `folding` conda environment ([link](http://www.pyrosetta.org/downloads)).
## Usage
```
# For monomer structure prediction
cd example
../run_[pyrosetta, e2e]_ver.sh input.fa .
# For complex modeling
# please see README file under example/complex_modeling/README for details.
python network/predict_complex.py -i paired.a3m -o complex -Ls 218 310
```
## Expected outputs
For the pyrosetta version, the user will get five final models with the estimated CA rms error in the B-factor column (model/model_[1-5].crderr.pdb).
For the end-to-end version, there will be a single PDB output with the estimated residue-wise CA-lddt in the B-factor column (t000_.e2e.pdb).
## FAQ
1. Segmentation fault while running hhblits/hhsearch
For easy install, we used a statically compiled version of hhsuite (installed through conda). Currently, we're not sure what exactly causes segmentation fault error in some cases, but we found that it might be resolved if you compile hhsuite from source and use this compiled version instead of conda version. For installation of hhsuite, please see [here](https://github.com/soedinglab/hh-suite).
2. Submitting jobs to computing nodes
The modeling pipeline provided here (run_pyrosetta_ver.sh/run_e2e_ver.sh) is intended as a guideline showing how RoseTTAFold works. For more efficient use of computing resources, you might want to modify the provided bash script to submit separate jobs with proper dependencies for each of the steps (more cpus/memory for hhblits/hhsearch, using gpus only for running the networks, etc).
## Links:
* [Robetta server](https://robetta.bakerlab.org/) (RoseTTAFold option)
* [RoseTTAFold models for CASP14 targets](https://files.ipd.uw.edu/pub/RoseTTAFold/casp14_models.tar.gz) [input MSA and hhsearch files are included]
## Credit to performer-pytorch and SE(3)-Transformer codes
The code in network/performer_pytorch.py is strongly based on [this repo](https://github.com/lucidrains/performer-pytorch), which is a PyTorch implementation of the [Performer architecture](https://arxiv.org/abs/2009.14794).
The code in network/equivariant_attention is from the original SE(3)-Transformer [repo](https://github.com/FabianFuchsML/se3-transformer-public), which accompanies [the paper](https://arxiv.org/abs/2006.10503) 'SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks' by Fabian et al.
## References
M Baek, et al., Accurate prediction of protein structures and interactions using a 3-track network, bioRxiv (2021). [link](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1)

View file

@ -0,0 +1,138 @@
# RoseTTAFold for PyTorch
This repository provides a script to run inference using the RoseTTAFold model. The content of this repository is tested and maintained by NVIDIA.
## Table Of Contents
- [Model overview](#model-overview)
* [Model architecture](#model-architecture)
- [Setup](#setup)
* [Requirements](#requirements)
- [Quick Start Guide](#quick-start-guide)
- [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
## Model overview
RoseTTAFold is a model designed to predict an accurate protein structure from its amino acid sequence. This model is
based on [Accurate prediction of protein structures and interactions using a 3-track network](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1) by Minkyung Baek et al.
This implementation is a dockerized version of the official [RoseTTAFold repository](https://github.com/RosettaCommons/RoseTTAFold/).
Here you can find the [original RoseTTAFold guide](README-ROSETTAFOLD.md).
### Model architecture
The RoseTTAFold model is based on a 3-track architecture fusing 1D, 2D, and 3D information about the protein structure.
All information is exchanged between tracks to learn the sequence and coordinate patterns at the same time. The final prediction
is refined using an SE(3)-Transformer.
<img src="images/NetworkArchitecture.jpg" width="900"/>
*Figure 1: The RoseTTAFold architecture. Image comes from the [original paper](https://www.biorxiv.org/content/10.1101/2021.06.14.448402v1).*
## Setup
The following section lists the requirements that you need to meet in order to run inference using the RoseTTAFold model.
### Requirements
This repository contains a Dockerfile that extends the PyTorch NGC container and encapsulates necessary dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- PyTorch 21.09-py3 NGC container
- Supported GPUs:
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
For those unable to use the PyTorch NGC container, to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
In addition, 1 TB of disk space is required to unpack the required databases.
## Quick Start Guide
To run inference using the RoseTTAFold model, perform the following steps using the default parameters.
1. Clone the repository.
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/DGLPyTorch/
```
2. Download the pre-trained weights and databases needed for inference.
The following command downloads the pre-trained weights and the two databases needed to create the derived features used as input to the model.
The script will download the `UniRef30` (~50 GB) and `pdb100_2021Mar03` (~115 GB) databases, which might take a considerable amount
of time. Additionally, unpacking those databases requires approximately 1 TB of free disk space.
By default, the data will be downloaded to `./weights` and `./databases` folders in the current directory.
```
bash scripts/download_databases.sh
```
If you would like to specify the download location you can pass the following parameters
```
bash scripts/download_databases.sh PATH-TO-WEIGHTS PATH-TO-DATABASES
```
3. Build the RoseTTAFold PyTorch NGC container. This step builds the PyTorch dependencies on your machine and can take between 30 minutes and 1 hour to complete.
```
docker build -t rosettafold .
```
4. Start an interactive session in the NGC container to run inference.
The following command launches the container and mounts the `PATH-TO-WEIGHTS` directory as a volume to the `/weights` directory in the container, the `PATH-TO-DATABASES` directory as a volume to the `/databases` directory in the container, and the `./results` directory to the `/results` directory in the container.
```
mkdir data results
docker run --ipc=host -it --rm --runtime=nvidia -p6006:6006 -v PATH-TO-WEIGHTS:/weights -v PATH-TO-DATABASES:/databases -v ${PWD}/results:/results rosettafold:latest /bin/bash
```
5. Start inference/predictions.
To run inference you have to prepare a FASTA file and pass a path to it or pass a sequence directly.
```
python run_inference_pipeline.py [Sequence]
```
There is an example FASTA file at `example/input.fa` for you to try. Running the inference pipeline consists of four steps:
1. Preparing the Multiple Sequence Alignments (MSAs)
2. Preparing the secondary structures
3. Preparing the templates
4. Iteratively refining the prediction
The first three steps can take between a couple of minutes and an hour, depending on the sequence.
The output will be stored in the `/results` directory as an `output.e2e.pdb` file.
6. Start Jupyter Notebook to run inference interactively.
To launch the application, copy the Notebook to the root folder.
```
cp notebooks/run_inference.ipynb .
```
To start Jupyter Notebook, run:
```
jupyter notebook run_inference.ipynb
```
For more information about Jupyter Notebook, refer to the Jupyter Notebook documentation.
## Release notes
### Changelog
October 2021
- Initial release
### Known issues
There are no known issues with this model.

View file

@ -0,0 +1,107 @@
# Conda environment specification for RoseTTAFold, CUDA 10.2 variant: the
# dependency list below pins cudatoolkit=10.2.89 and the cuda10.2 / cu102
# builds of PyTorch and its extension packages.
name: RoseTTAFold
# Channels from which the packages under `dependencies` are resolved.
channels:
- rusty1s
- pytorch
- nvidia
- conda-forge
- defaults
- bioconda
- biocore
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- biopython=1.78=py38h497a2fe_2
- blas=1.0=mkl
- hhsuite
- blast-legacy=2.2.26=2
- brotlipy=0.7.0=py38h497a2fe_1001
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2021.7.5=h06a4308_1
- certifi=2021.5.30=py38h06a4308_0
- cffi=1.14.5=py38ha65f79e_0
- chardet=4.0.0=py38h578d9bd_1
- cryptography=3.4.7=py38ha5dfef3_0
- cudatoolkit=10.2.89=h8f6ccaa_8
- ffmpeg=4.3=hf484d3e_0
- freetype=2.10.4=h5ab3b9f_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- googledrivedownloader=0.4=pyhd3deb0d_1
- idna=2.10=pyh9f0ad1d_0
- intel-openmp=2021.2.0=h06a4308_610
- jinja2=3.0.1=pyhd8ed1ab_0
- joblib=1.0.1=pyhd8ed1ab_0
- jpeg=9b=h024ee3a_2
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=h14aa051_19
- libgfortran4=7.5.0=h14aa051_19
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.1=h27cfd23_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp-base=1.2.0=h27cfd23_0
- lz4-c=1.9.3=h2531618_0
- markupsafe=2.0.1=py38h497a2fe_0
- mkl=2021.2.0=h06a4308_296
- mkl-service=2.3.0=py38h27cfd23_1
- mkl_fft=1.3.0=py38h42c9631_2
- mkl_random=1.2.1=py38ha9443f7_2
- ncurses=6.2=he6710b0_1
- nettle=3.7.3=hbbd107a_1
- networkx=2.5=py_0
- ninja=1.10.2=hff7bd54_1
- numpy=1.20.2=py38h2d18471_0
- numpy-base=1.20.2=py38hfae3a4d_0
- olefile=0.46=py_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1k=h27cfd23_0
- packaging=20.9=pyhd3eb1b0_0
- pandas=1.2.5=py38h1abd341_0
- pillow=8.2.0=py38he98fc37_0
- pip=21.1.3=py38h06a4308_0
- psipred=4.01=1
- pycparser=2.20=pyh9f0ad1d_2
- pyopenssl=20.0.1=pyhd8ed1ab_0
- pyparsing=2.4.7=pyh9f0ad1d_0
- pysocks=1.7.1=py38h578d9bd_3
- python=3.8.10=h12debd9_8
- python-dateutil=2.8.1=py_0
- python-louvain=0.15=pyhd3deb0d_0
- python_abi=3.8=2_cp38
- pytorch=1.8.1=py3.8_cuda10.2_cudnn7.6.5_0
- pytorch-cluster=1.5.9=py38_torch_1.8.0_cu102
- pytorch-geometric=1.7.2=py38_torch_1.8.0_cu102
- pytorch-scatter=2.0.7=py38_torch_1.8.0_cu102
- pytorch-sparse=0.6.10=py38_torch_1.8.0_cu102
- pytorch-spline-conv=1.2.1=py38_torch_1.8.0_cu102
- pytz=2021.1=pyhd8ed1ab_0
- readline=8.1=h27cfd23_0
- requests=2.25.1=pyhd3deb0d_0
- scikit-learn=0.24.2=py38ha9443f7_0
- setuptools=52.0.0=py38h06a4308_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hc218d9a_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.10=hbc83047_0
- torchvision=0.9.1=py38_cu102
- tqdm=4.61.1=pyhd8ed1ab_0
- typing_extensions=3.10.0.0=pyh06a4308_0
- urllib3=1.26.6=pyhd8ed1ab_0
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- decorator==4.4.2
- dgl-cu102==0.6.1
- lie-learn==0.0.1.post1
- scipy==1.7.0

View file

@ -0,0 +1,108 @@
# Conda environment specification for RoseTTAFold, CUDA 11.1 variant: the
# dependency list below pins cudatoolkit=11.1.74 and the cuda11.1 / cu111
# builds of PyTorch and its extension packages.
name: RoseTTAFold
# Channels from which the packages under `dependencies` are resolved.
channels:
- rusty1s
- pytorch
- nvidia
- conda-forge
- defaults
- bioconda
- biocore
dependencies:
- biopython=1.78
- biocore::blast-legacy=2.2.26
- hhsuite
- psipred=4.01
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- blas=1.0=mkl
- brotlipy=0.7.0=py38h497a2fe_1001
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2021.5.25=h06a4308_1
- certifi=2021.5.30=py38h06a4308_0
- cffi=1.14.5=py38ha65f79e_0
- chardet=4.0.0=py38h578d9bd_1
- cryptography=3.4.7=py38ha5dfef3_0
- cudatoolkit=11.1.74=h6bb024c_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.10.4=h5ab3b9f_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- googledrivedownloader=0.4=pyhd3deb0d_1
- idna=2.10=pyh9f0ad1d_0
- intel-openmp=2021.2.0=h06a4308_610
- jinja2=3.0.1=pyhd8ed1ab_0
- joblib=1.0.1=pyhd8ed1ab_0
- jpeg=9b=h024ee3a_2
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=h14aa051_19
- libgfortran4=7.5.0=h14aa051_19
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.1=h27cfd23_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp-base=1.2.0=h27cfd23_0
- lz4-c=1.9.3=h2531618_0
- markupsafe=2.0.1=py38h497a2fe_0
- mkl=2021.2.0=h06a4308_296
- mkl-service=2.3.0=py38h27cfd23_1
- mkl_fft=1.3.0=py38h42c9631_2
- mkl_random=1.2.1=py38ha9443f7_2
- ncurses=6.2=he6710b0_1
- nettle=3.7.3=hbbd107a_1
- networkx=2.5=py_0
- ninja=1.10.2=hff7bd54_1
- numpy=1.20.2=py38h2d18471_0
- numpy-base=1.20.2=py38hfae3a4d_0
- olefile=0.46=py_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1k=h27cfd23_0
- packaging=20.9=pyhd3eb1b0_0
- pandas=1.2.5=py38h1abd341_0
- pillow=8.2.0=py38he98fc37_0
- pip=21.1.3=py38h06a4308_0
- pycparser=2.20=pyh9f0ad1d_2
- pyopenssl=20.0.1=pyhd8ed1ab_0
- pyparsing=2.4.7=pyh9f0ad1d_0
- pysocks=1.7.1=py38h578d9bd_3
- python=3.8.10=h12debd9_8
- python-dateutil=2.8.1=py_0
- python-louvain=0.15=pyhd3deb0d_0
- python_abi=3.8=2_cp38
- pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
- pytorch-cluster=1.5.9=py38_torch_1.9.0_cu111
- pytorch-geometric=1.7.2=py38_torch_1.9.0_cu111
- pytorch-scatter=2.0.7=py38_torch_1.9.0_cu111
- pytorch-sparse=0.6.10=py38_torch_1.9.0_cu111
- pytorch-spline-conv=1.2.1=py38_torch_1.9.0_cu111
- pytz=2021.1=pyhd8ed1ab_0
- readline=8.1=h27cfd23_0
- requests=2.25.1=pyhd3deb0d_0
- scikit-learn=0.24.2=py38ha9443f7_0
- setuptools=52.0.0=py38h06a4308_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hc218d9a_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.10=hbc83047_0
- torchaudio=0.9.0=py38
- torchvision=0.10.0=py38_cu111
- tqdm=4.61.1=pyhd8ed1ab_0
- typing_extensions=3.10.0.0=pyh06a4308_0
- urllib3=1.26.6=pyhd8ed1ab_0
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- decorator==4.4.2
- dgl-cu110==0.6.1
- scipy==1.7.0
- lie-learn==0.0.1.post1

View file

@ -0,0 +1,2 @@
>T1078 Tsp1, Trichoderma virens, 138 residues|
MAAPTPADKSMMAAVPEWTITNLKRVCNAGNTSCTWTFGVDTHLATATSCTYVVKANANASQASGGPVTCGPYTITSSWSGQFGPNNGFTTFAVTDFSKKLIVWPAYTDVQVQAGKVVSPNQSYAPANLPLEHHHHHH

View file

@ -0,0 +1,9 @@
# Conda environment specification for the `folding` environment.
# NOTE(review): presumably the environment the README creates for the
# PyRosetta folding step and DeepAccNet -- confirm against the file name.
name: folding
# Channels from which the packages below are resolved.
channels:
- defaults
- conda-forge
dependencies:
# Only tensorflow-gpu and scikit-learn carry a version pin; pandas and
# parallel are left unpinned.
- tensorflow-gpu=1.14
- pandas
- scikit-learn=0.24
- parallel

View file

@ -0,0 +1,43 @@
###############################################################################
RosettaTR.py: Relax on dualspace (2 rounds of relax w/ different set of
restraints)
###############################################################################
usage: RosettaTR.py [-h] [-r NRESTARTS] [-pd PCUT] [-m {0,1,2}] [-bb BB]
[-sg SG] [-n STEPS] [--save_chk] [--orient] [--no-orient]
[--fastrelax] [--no-fastrelax] [--roll] [--no-roll]
NPZ FASTA OUT
positional arguments:
NPZ input distograms and anglegrams (NN predictions)
FASTA input sequence
OUT output model (in PDB format)
optional arguments:
-h, --help show this help message and exit
-r NRESTARTS number of noisy restrarts (default: 3)
-pd PCUT min probability of distance restraints (default: 0.05)
-m {0,1,2} 0: sh+m+l, 1: (sh+m)+l, 2: (sh+m+l) (default: 2)
-bb BB predicted backbone torsions (default: )
-sg SG window size and order for a Savitzky-Golay filter (comma-
separated) (default: )
-n STEPS number of minimization steps (default: 1000)
--save_chk save checkpoint files to restart (default: False)
--orient use orientations (default: True)
--no-orient
--fastrelax perform FastRelax (default: True)
--no-fastrelax
--roll circularly shift 6d coordinate arrays by 1 (default: False)
--no-roll
# USAGE
conda activate folding
# try: -m 0,1,2
# -pd 0.05, 0.15, 0.25, 0.35, 0.45
# repeat ~3-5 times for every combination of -m and -pd
# !!! use '--roll' option if no-contact bin is the last one !!!
python ./RosettaTR.py -m 0 -pd 0.15 fake.npz fake.fa model.pdb

View file

@ -0,0 +1,340 @@
import sys,os,json
import tempfile
import numpy as np
from arguments import *
from utils import *
from pyrosetta import *
from pyrosetta.rosetta.protocols.minimization_packing import MinMover
# Per-run score-term weights, keyed by the 0-based minimization run index.
# main() looks these up with fallbacks of 10.0 (vdw), 1.0 (distance restraints)
# and 0.5 (orientation restraints) for any run index that is not listed.
vdw_weight = {0: 3.0, 1: 5.0, 2: 10.0}
# The last key was 3 in the two restraint tables, unlike vdw_weight. It is 2 now
# so that all three tables cover runs 0..2 explicitly; the stored values equal
# the fallbacks, so the weight obtained for every run index is unchanged.
rsr_dist_weight = {0: 3.0, 1: 2.0, 2: 1.0}
rsr_orient_weight = {0: 1.0, 1: 1.0, 2: 0.5}
def _set_disulf_tolerance(tolerance):
    """Set the 'in:detect_disulf_tolerance' option, printing its value before and after."""
    print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))
    rosetta.basic.options.set_real_option('in:detect_disulf_tolerance', tolerance)
    print (rosetta.basic.options.get_real_option('in:detect_disulf_tolerance'))


def main():
    """Fold a sequence from predicted restraints and write the model to args.OUT.

    Steps, in order:
      1. Load data/params.json next to this script, parse the command line
         (get_args) and return immediately if args.OUT already exists.
      2. Initialise PyRosetta, read the FASTA sequence and generate restraints.
      3. Unless a "<prefix>_before_relax.pdb" checkpoint exists: build a
         centroid pose, run params['NRUNS'] restrained minimization runs keeping
         the pose with the lowest sf_cart score, then switch to full-atom and
         idealize residues whose cart_bonded energy exceeds 50. If the
         checkpoint exists, it is loaded instead.
      4. If args.fastrelax is set, run two FastRelax rounds.
      5. Dump the final pose to args.OUT.

    "<prefix>" is args.OUT with its last '.'-separated component removed.
    """

    ########################################################
    # process inputs
    ########################################################

    # read params
    scriptdir = os.path.dirname(os.path.realpath(__file__))
    with open(scriptdir + '/data/params.json') as jsonfile:
        params = json.load(jsonfile)

    # get command line arguments
    args = get_args(params)
    print(args)
    # nothing to do when the output model is already there
    if os.path.exists(args.OUT):
        return

    # common prefix of every checkpoint file written/read below
    out_prefix = '.'.join(args.OUT.split('.')[:-1])
    before_relax_pdb = "%s_before_relax.pdb"%out_prefix

    # init PyRosetta
    init_cmd = list()
    init_cmd.append("-multithreading:interaction_graph_threads 1 -multithreading:total_threads 1")
    init_cmd.append("-hb_cen_soft")
    init_cmd.append("-detect_disulf -detect_disulf_tolerance 2.0") # detect disulfide bonds based on Cb-Cb distance (CEN mode) or SG-SG distance (FA mode)
    init_cmd.append("-relax:dualspace true -relax::minimize_bond_angles -default_max_cycles 200")
    init_cmd.append("-mute all")
    init(" ".join(init_cmd))

    # read and process restraints & sequence
    seq = read_fasta(args.FASTA)
    L = len(seq)
    params['seq'] = seq
    rst = gen_rst(params)

    ########################################################
    # Scoring functions and movers
    ########################################################
    sf = ScoreFunction()
    sf.add_weights_from_file(scriptdir + '/data/scorefxn.wts')

    sf1 = ScoreFunction()
    sf1.add_weights_from_file(scriptdir + '/data/scorefxn1.wts')

    sf_vdw = ScoreFunction()
    sf_vdw.add_weights_from_file(scriptdir + '/data/scorefxn_vdw.wts')

    sf_cart = ScoreFunction()
    sf_cart.add_weights_from_file(scriptdir + '/data/scorefxn_cart.wts')

    mmap = MoveMap()
    mmap.set_bb(True)
    mmap.set_chi(False)
    mmap.set_jump(True)

    min_mover1 = MinMover(mmap, sf1, 'lbfgs_armijo_nonmonotone', 0.001, True)
    min_mover1.max_iter(1000)

    min_mover_vdw = MinMover(mmap, sf_vdw, 'lbfgs_armijo_nonmonotone', 0.001, True)
    min_mover_vdw.max_iter(500)

    min_mover_cart = MinMover(mmap, sf_cart, 'lbfgs_armijo_nonmonotone', 0.000001, True)
    min_mover_cart.max_iter(300)
    min_mover_cart.cartesian(True)

    # Restraint stages applied within one run, per -m mode, as
    # (label printed, 2nd add_rst bound, 3rd add_rst bound).
    # The bounds are presumably sequence-separation limits -- confirm in add_rst.
    stages_by_mode = {
        0: [('short', 3, 12), ('medium', 12, 24), ('long', 24, L)],
        1: [('short + medium', 3, 24), ('long', 24, L)],
        2: [('short + medium + long', 3, L)],
    }

    if not os.path.exists(before_relax_pdb):

        ########################################################
        # initialize pose
        ########################################################
        pose0 = pose_from_sequence(seq, 'centroid')

        # mutate GLY to ALA
        for i,a in enumerate(seq):
            if a == 'G':
                mutator = rosetta.protocols.simple_moves.MutateResidue(i+1,'ALA')
                mutator.apply(pose0)
                print('mutation: G%dA'%(i+1))

        if (args.bb == ''):
            print('setting random (phi,psi,omega)...')
            set_random_dihedral(pose0)
        else:
            print('setting predicted (phi,psi,omega)...')
            bb = np.load(args.bb)
            set_predicted_dihedral(pose0,bb['phi'],bb['psi'],bb['omega'])

        remove_clash(sf_vdw, min_mover_vdw, pose0)

        Emin = 99999.9

        ########################################################
        # minimization
        ########################################################

        for run in range(params['NRUNS']):
            # define repeat_mover here!! (update vdw weights: weak (1.0) -> strong (10.0)
            # .get() instead of .setdefault(): same value returned, but the
            # module-level weight tables are no longer mutated as a side effect.
            sf.set_weight(rosetta.core.scoring.vdw, vdw_weight.get(run, 10.0))
            sf.set_weight(rosetta.core.scoring.atom_pair_constraint, rsr_dist_weight.get(run, 1.0))
            orient_weight = rsr_orient_weight.get(run, 0.5)
            sf.set_weight(rosetta.core.scoring.dihedral_constraint, orient_weight)
            sf.set_weight(rosetta.core.scoring.angle_constraint, orient_weight)

            min_mover = MinMover(mmap, sf, 'lbfgs_armijo_nonmonotone', 0.001, True)
            min_mover.max_iter(1000)

            repeat_mover = RepeatMover(min_mover, 3)

            # every run starts from the best pose so far, without constraints
            pose = Pose()
            pose.assign(pose0)
            pose.remove_constraints()

            if run > 0:

                # diversify backbone
                dphi = np.random.uniform(-10,10,L)
                dpsi = np.random.uniform(-10,10,L)
                for i in range(1,L+1):
                    pose.set_phi(i,pose.phi(i)+dphi[i-1])
                    pose.set_psi(i,pose.psi(i)+dpsi[i-1])

                # remove clashes
                remove_clash(sf_vdw, min_mover_vdw, pose)

            # Save checkpoint
            if args.save_chk:
                pose.dump_pdb("%s_run%d_init.pdb"%(out_prefix, run))

            # An unknown mode applies no stage, as in the original if/elif chain.
            for step, (label, bound_lo, bound_hi) in enumerate(stages_by_mode.get(args.mode, [])):
                print(label)
                add_rst(pose, rst, bound_lo, bound_hi, params)
                repeat_mover.apply(pose)
                remove_clash(sf_vdw, min_mover1, pose)
                min_mover_cart.apply(pose)
                if args.save_chk:
                    pose.dump_pdb("%s_run%d_mode%d_step%d.pdb"%(out_prefix, run, args.mode, step))

            # check whether energy has decreased
            pose.conformation().detect_disulfides() # detect disulfide bonds
            E = sf_cart(pose)
            if E < Emin:
                print("Energy(iter=%d): %.1f --> %.1f (accept)"%(run, Emin, E))
                Emin = E
                pose0.assign(pose)
            else:
                print("Energy(iter=%d): %.1f --> %.1f (reject)"%(run, Emin, E))

        # mutate ALA back to GLY
        for i,a in enumerate(seq):
            if a == 'G':
                mutator = rosetta.protocols.simple_moves.MutateResidue(i+1,'GLY')
                mutator.apply(pose0)
                print('mutation: A%dG'%(i+1))

        ########################################################
        # fix backbone geometry
        ########################################################
        pose0.remove_constraints()

        # apply more strict criteria to detect disulfide bond
        # Set options for disulfide tolerance -> 1.0A
        _set_disulf_tolerance(1.0)
        pose0.conformation().detect_disulfides()

        # Convert to all atom representation
        switch = SwitchResidueTypeSetMover("fa_standard")
        switch.apply(pose0)

        # idealize problematic local regions if exists
        idealize = rosetta.protocols.idealize.IdealizeMover()
        poslist = rosetta.utility.vector1_unsigned_long()

        scorefxn=create_score_function('empty')
        scorefxn.set_weight(rosetta.core.scoring.cart_bonded, 1.0)
        scorefxn.score(pose0)

        emap = pose0.energies()
        print("idealize...")
        for res in range(1,L+1):
            cart = emap.residue_total_energy(res)
            if cart > 50:
                poslist.append(res)
                print( "idealize %d %8.3f"%(res,cart) )

        if len(poslist) > 0:
            idealize.set_pos_list(poslist)
            try:
                idealize.apply(pose0)
            except Exception:
                # Best effort: carry on with the non-idealized pose. Catching
                # Exception (not a bare except) lets Ctrl-C / SystemExit through.
                print('!!! idealization failed !!!')

        # Save checkpoint
        if args.save_chk:
            pose0.dump_pdb(before_relax_pdb)

    else: # checkpoint exists
        pose0 = pose_from_file(before_relax_pdb)
        #
        print ("to centroid")
        switch_cen = SwitchResidueTypeSetMover("centroid")
        switch_cen.apply(pose0)
        #
        print ("detect disulfide bonds")
        # Set options for disulfide tolerance -> 1.0A
        _set_disulf_tolerance(1.0)
        pose0.conformation().detect_disulfides()
        #
        print ("to all atom")
        switch = SwitchResidueTypeSetMover("fa_standard")
        switch.apply(pose0)

    ########################################################
    # full-atom refinement
    ########################################################

    if args.fastrelax:
        mmap = MoveMap()
        mmap.set_bb(True)
        mmap.set_chi(True)
        mmap.set_jump(True)

        # First round: Repeat 2 torsion space relax w/ strong disto/anglogram constraints
        sf_fa_round1 = create_score_function('ref2015_cart')
        sf_fa_round1.set_weight(rosetta.core.scoring.atom_pair_constraint, 3.0)
        sf_fa_round1.set_weight(rosetta.core.scoring.dihedral_constraint, 1.0)
        sf_fa_round1.set_weight(rosetta.core.scoring.angle_constraint, 1.0)
        sf_fa_round1.set_weight(rosetta.core.scoring.pro_close, 0.0)

        relax_round1 = rosetta.protocols.relax.FastRelax(sf_fa_round1, "%s/data/relax_round1.txt"%scriptdir)
        relax_round1.set_movemap(mmap)

        print('relax: First round... (focused on torsion space relaxation)')
        params['PCUT'] = 0.15
        pose0.remove_constraints()
        add_rst(pose0, rst, 3, len(seq), params, nogly=True, use_orient=True)
        relax_round1.apply(pose0)

        # Set options for disulfide tolerance -> 0.5A
        _set_disulf_tolerance(0.5)

        sf_fa = create_score_function('ref2015_cart')
        sf_fa.set_weight(rosetta.core.scoring.atom_pair_constraint, 0.1)
        sf_fa.set_weight(rosetta.core.scoring.dihedral_constraint, 0.0)
        sf_fa.set_weight(rosetta.core.scoring.angle_constraint, 0.0)

        relax_round2 = rosetta.protocols.relax.FastRelax(sf_fa, "%s/data/relax_round2.txt"%scriptdir)
        relax_round2.set_movemap(mmap)
        relax_round2.cartesian(True)
        relax_round2.dualspace(True)

        print('relax: Second round... (cartesian space)')
        params['PCUT'] = 0.30 # To reduce the number of pair restraints..
        pose0.remove_constraints()
        pose0.conformation().detect_disulfides() # detect disulfide bond again w/ stricter cutoffs
        # To reduce the number of constraints, only pair distances are considered w/ higher prob cutoffs
        add_rst(pose0, rst, 3, len(seq), params, nogly=True, use_orient=False, p12_cut=params['PCUT'])
        # Instead, apply CA coordinate constraints to prevent drifting away too much (focus on local refinement?)
        add_crd_rst(pose0, L, std=1.0, tol=2.0)
        relax_round2.apply(pose0)

        # Re-evaluate score w/o any constraints
        scorefxn_min=create_score_function('ref2015_cart')
        scorefxn_min.score(pose0)

    ########################################################
    # save final model
    ########################################################
    pose0.dump_pdb(args.OUT)


if __name__ == '__main__':
    main()

View file

@ -0,0 +1,38 @@
import argparse
def get_args(params):
    """Parse the folding command line and mirror the relevant options into ``params``.

    Args:
        params: dict of run parameters. ``params['NRUNS']`` and ``params['PCUT']``
            supply the defaults for ``-r`` and ``-pd``. The dict is updated in place
            with the keys PCUT, USE_ORIENT, NRUNS, ROLL, NPZ, BB and SG taken from
            the parsed options.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` (which reads
        ``sys.argv``).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # positional inputs / outputs
    parser.add_argument("NPZ", type=str, help="input distograms and anglegrams (NN predictions)")
    parser.add_argument("FASTA", type=str, help="input sequence")
    parser.add_argument("OUT", type=str, help="output model (in PDB format)")

    # tunables
    parser.add_argument('-r', type=int, dest='nrestarts', default=params['NRUNS'], help='number of noisy restarts')
    parser.add_argument('-pd', type=float, dest='pcut', default=params['PCUT'], help='min probability of distance restraints')
    parser.add_argument('-m', type=int, dest='mode', default=2, choices=[0,1,2], help='0: sh+m+l, 1: (sh+m)+l, 2: (sh+m+l)')
    parser.add_argument('-bb', type=str, dest='bb', default='', help='predicted backbone torsions')
    parser.add_argument('-sg', type=str, dest='sg', default='', help='window size and order for a Savitzky-Golay filter (comma-separated)')
    parser.add_argument('-n', type=int, dest='steps', default=1000, help='number of minimization steps')
    parser.add_argument('--save_chk', dest='save_chk', default=False, action='store_true', help='save checkpoint files to restart')

    # paired on/off switches; their defaults are set together below
    parser.add_argument('--orient', dest='use_orient', action='store_true', help='use orientations')
    parser.add_argument('--no-orient', dest='use_orient', action='store_false')
    parser.add_argument('--fastrelax', dest='fastrelax', action='store_true', help='perform FastRelax')
    parser.add_argument('--no-fastrelax', dest='fastrelax', action='store_false')
    parser.add_argument('--roll', dest='roll', action='store_true', help='circularly shift 6d coordinate arrays by 1')
    parser.add_argument('--no-roll', dest='roll', action='store_false')
    parser.set_defaults(use_orient=True, fastrelax=True, roll=False)

    args = parser.parse_args()

    # copy the options that other modules read from the shared params dict
    params['PCUT'] = args.pcut
    params['USE_ORIENT'] = args.use_orient
    params['NRUNS'] = args.nrestarts
    params['ROLL'] = args.roll
    params['NPZ'] = args.NPZ
    params['BB'] = args.bb
    params['SG'] = args.sg
    return args

View file

@ -0,0 +1,17 @@
{
"PCUT" : 0.05,
"PCUT1" : 0.5,
"EBASE" : -0.5,
"EREP" : [10.0,3.0,0.5],
"DREP" : [0.0,2.0,3.5],
"PREP" : 0.1,
"SIGD" : 10.0,
"SIGM" : 1.0,
"MEFF" : 0.0001,
"DCUT" : 19.5,
"ALPHA" : 1.57,
"DSTEP" : 0.5,
"ASTEP" : 15.0,
"BBWGHT" : 10.0,
"NRUNS" : 3
}

View file

@ -0,0 +1,17 @@
switch:torsion
repeat 2
ramp_repack_min 0.02 0.01 1.0 100
ramp_repack_min 0.250 0.01 0.5 100
ramp_repack_min 0.550 0.01 0.1 100
ramp_repack_min 1 0.00001 0.1 100
accept_to_best
endrepeat
switch:cartesian
repeat 1
ramp_repack_min 0.02 0.01 1.0 50
ramp_repack_min 0.250 0.01 0.5 50
ramp_repack_min 0.550 0.01 0.1 100
ramp_repack_min 1 0.00001 0.1 200
accept_to_best
endrepeat

View file

@ -0,0 +1,8 @@
switch:cartesian
repeat 2
ramp_repack_min 0.02 0.01 1.0 50
ramp_repack_min 0.250 0.01 0.5 50
ramp_repack_min 0.550 0.01 0.1 100
ramp_repack_min 1 0.00001 0.1 200
accept_to_best
endrepeat

View file

@ -0,0 +1,7 @@
cen_hb 5.0
rama 1.0
omega 0.5
vdw 1.0
atom_pair_constraint 5.0
dihedral_constraint 4.0
angle_constraint 4.0

View file

@ -0,0 +1,7 @@
cen_hb 5.0
rama 1.0
omega 0.5
vdw 5.0
atom_pair_constraint 1.0
dihedral_constraint 0.5
angle_constraint 0.5

View file

@ -0,0 +1,9 @@
hbond_sr_bb 3.0
hbond_lr_bb 3.0
rama 1.0
omega 0.5
vdw 5.0
cart_bonded 5.0
atom_pair_constraint 1.0
dihedral_constraint 0.5
angle_constraint 0.5

View file

@ -0,0 +1,2 @@
rama 1.0
vdw 1.0

View file

@ -0,0 +1,373 @@
import numpy as np
import random
import scipy
from scipy.signal import *
from pyrosetta import *
eps = 1e-9
P_ADD_OMEGA = 0.5
P_ADD_THETA = 0.5
P_ADD_PHI = 0.6
def gen_rst(params):
# Build all candidate Rosetta restraints from the network output stored in params['NPZ'].
#
# Reads the 'dist', 'omega', 'theta' and 'phi' arrays from the npz file, turns each
# binned probability profile into a negative-log "energy" profile and wraps it in a
# SplineFunc-based constraint. When params['BB'] != '' the same is done for the
# per-residue backbone phi/psi profiles stored in that second npz file.
#
# params: dict; keys read here are NPZ, ROLL, EBASE, EREP, DREP, PREP, SIGD, SIGM, MEFF,
#         DCUT, ALPHA, BBWGHT, DSTEP, ASTEP, seq, SG and BB.
# Returns: dict with lists under 'dist' (entries [a, b, p, p_12, constraint]) and under
#          'omega' / 'theta' / 'phi' (entries [a, b, p, constraint]); 'bbphi' / 'bbpsi'
#          (bare constraints) exist only when params['BB'] is set.
#          a and b are 0-based array indices; the AtomIDs are built with a+1 / b+1.
npz = np.load(params['NPZ'])
dist,omega,theta,phi = npz['dist'],npz['omega'],npz['theta'],npz['phi']
# optional circular shift of the bin axis (last axis) by one position
if params['ROLL']==True:
print("Apply circular shift...")
dist = np.roll(dist,1,axis=-1)
omega = np.roll(omega,1,axis=-1)
theta = np.roll(theta,1,axis=-1)
phi = np.roll(phi,1,axis=-1)
# eps keeps every bin strictly positive before the logs taken below
dist = dist.astype(np.float32) + eps
omega = omega.astype(np.float32) + eps
theta = theta.astype(np.float32) + eps
phi = phi.astype(np.float32) + eps
# dictionary to store Rosetta restraints
rst = {'dist' : [], 'omega' : [], 'theta' : [], 'phi' : []}
########################################################
# assign parameters
########################################################
# NOTE: the generation cutoff is hard-coded; params['PCUT'] is not read in this function.
PCUT = 0.05 #params['PCUT']
EBASE = params['EBASE']
EREP = params['EREP']
DREP = params['DREP']
# NOTE(review): PREP, SIGD, SIGM and seq are read but not used further down in this function.
PREP = params['PREP']
SIGD = params['SIGD']
SIGM = params['SIGM']
MEFF = params['MEFF']
DCUT = params['DCUT']
ALPHA = params['ALPHA']
BBWGHT = params['BBWGHT']
DSTEP = params['DSTEP']
# NOTE: this ASTEP is overwritten in the omega section before its first use.
ASTEP = np.deg2rad(params['ASTEP'])
seq = params['seq']
# optional Savitzky-Golay smoothing, given as "window,order" (applied to the distance profiles only)
sg_flag = False
if params['SG'] != '':
sg_flag = True
sg_w,sg_n = [int(v) for v in params['SG'].split(",")]
print("Savitzky-Golay: %d,%d"%(sg_w,sg_n))
########################################################
# dist: 0..20A
########################################################
nres = dist.shape[0]
# 32 grid points starting at 4.25 with spacing DSTEP, one per bin of dist[:,:,5:]
# (assumes dist has 37 bins along the last axis -- TODO confirm)
bins = np.array([4.25+DSTEP*i for i in range(32)])
prob = np.sum(dist[:,:,5:], axis=-1) # prob of dist within 20A
prob_12 = np.sum(dist[:,:,5:21], axis=-1) # prob of dist within 12A
# background term (bins/DCUT)**ALPHA used to normalise the last-bin probability
bkgr = np.array((bins/DCUT)**ALPHA)
# attractive part: -log of (bin probability + MEFF) over (last-bin probability * background), shifted by EBASE
attr = -np.log((dist[:,:,5:]+MEFF)/(dist[:,:,-1][:,:,None]*bkgr[None,None,:]))+EBASE
# repulsive part: first attr value clipped at zero, plus the EREP offsets (one per DREP distance)
repul = np.maximum(attr[:,:,0],np.zeros((nres,nres)))[:,:,None]+np.array(EREP)[None,None,:]
dist = np.concatenate([repul,attr], axis=-1)
bins = np.concatenate([DREP,bins])
# spline x-axis shared by all distance restraints
x = pyrosetta.rosetta.utility.vector1_double()
_ = [x.append(v) for v in bins]
#
prob = np.triu(prob, k=1) # fill zeros to diagonal and lower (for speed-up)
i,j = np.where(prob>PCUT)
prob = prob[i,j]
prob_12 = prob_12[i,j]
#nbins = 35
# NOTE(review): the bin width passed to SplineFunc is a literal 0.5, independent of DSTEP.
step = 0.5
for a,b,p,p_12 in zip(i,j,prob,prob_12):
y = pyrosetta.rosetta.utility.vector1_double()
if sg_flag == True:
_ = [y.append(v) for v in savgol_filter(dist[a,b],sg_w,sg_n)]
else:
_ = [y.append(v) for v in dist[a,b]]
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, step, x,y)
# atom index 5 is the atom labelled CB in the omega/theta/phi sections below
ida = rosetta.core.id.AtomID(5,a+1)
idb = rosetta.core.id.AtomID(5,b+1)
rst['dist'].append([a,b,p,p_12,rosetta.core.scoring.constraints.AtomPairConstraint(ida, idb, spline)])
print("dist restraints: %d"%(len(rst['dist'])))
########################################################
# omega: -pi..pi
########################################################
# bin 0 is excluded from the angular bins; ASTEP is recomputed from the bin count here
nbins = omega.shape[2]-1
ASTEP = 2.0*np.pi/nbins
# two extra grid points on each side for the periodic padding built below
nbins += 4
bins = np.linspace(-np.pi-1.5*ASTEP, np.pi+1.5*ASTEP, nbins)
x = pyrosetta.rosetta.utility.vector1_double()
_ = [x.append(v) for v in bins]
prob = np.sum(omega[:,:,1:], axis=-1)
prob = np.triu(prob, k=1) # fill zeros to diagonal and lower (for speed-up)
i,j = np.where(prob>PCUT+P_ADD_OMEGA)
prob = prob[i,j]
# -log of each bin relative to the last bin
omega = -np.log((omega+MEFF)/(omega[:,:,-1]+MEFF)[:,:,None])
#if sg_flag == True:
# omega = savgol_filter(omega,sg_w,sg_n,axis=-1,mode='wrap')
# wrap-around padding: last two bins in front, first two angular bins at the end
omega = np.concatenate([omega[:,:,-2:],omega[:,:,1:],omega[:,:,1:3]],axis=-1)
for a,b,p in zip(i,j,prob):
y = pyrosetta.rosetta.utility.vector1_double()
_ = [y.append(v) for v in omega[a,b]]
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
id1 = rosetta.core.id.AtomID(2,a+1) # CA-i
id2 = rosetta.core.id.AtomID(5,a+1) # CB-i
id3 = rosetta.core.id.AtomID(5,b+1) # CB-j
id4 = rosetta.core.id.AtomID(2,b+1) # CA-j
rst['omega'].append([a,b,p,rosetta.core.scoring.constraints.DihedralConstraint(id1,id2,id3,id4, spline)])
print("omega restraints: %d"%(len(rst['omega'])))
########################################################
# theta: -pi..pi
########################################################
# reuses x and ASTEP from the omega section; only the diagonal is zeroed, so both
# (a,b) and (b,a) can be selected
prob = np.sum(theta[:,:,1:], axis=-1)
np.fill_diagonal(prob, 0.0)
i,j = np.where(prob>PCUT+P_ADD_THETA)
prob = prob[i,j]
theta = -np.log((theta+MEFF)/(theta[:,:,-1]+MEFF)[:,:,None])
#if sg_flag == True:
# theta = savgol_filter(theta,sg_w,sg_n,axis=-1,mode='wrap')
# same wrap-around padding as for omega
theta = np.concatenate([theta[:,:,-2:],theta[:,:,1:],theta[:,:,1:3]],axis=-1)
for a,b,p in zip(i,j,prob):
y = pyrosetta.rosetta.utility.vector1_double()
_ = [y.append(v) for v in theta[a,b]]
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
id1 = rosetta.core.id.AtomID(1,a+1) # N-i
id2 = rosetta.core.id.AtomID(2,a+1) # CA-i
id3 = rosetta.core.id.AtomID(5,a+1) # CB-i
id4 = rosetta.core.id.AtomID(5,b+1) # CB-j
rst['theta'].append([a,b,p,rosetta.core.scoring.constraints.DihedralConstraint(id1,id2,id3,id4, spline)])
print("theta restraints: %d"%(len(rst['theta'])))
########################################################
# phi: 0..pi
########################################################
# the grid spacing is still the ASTEP computed in the omega section
nbins = phi.shape[2]-1+4
bins = np.linspace(-1.5*ASTEP, np.pi+1.5*ASTEP, nbins)
x = pyrosetta.rosetta.utility.vector1_double()
_ = [x.append(v) for v in bins]
prob = np.sum(phi[:,:,1:], axis=-1)
np.fill_diagonal(prob, 0.0)
i,j = np.where(prob>PCUT+P_ADD_PHI)
prob = prob[i,j]
phi = -np.log((phi+MEFF)/(phi[:,:,-1]+MEFF)[:,:,None])
#if sg_flag == True:
# phi = savgol_filter(phi,sg_w,sg_n,axis=-1,mode='mirror')
# mirrored (not wrapped) padding: the two edge bins on each side are reflected
phi = np.concatenate([np.flip(phi[:,:,1:3],axis=-1),phi[:,:,1:],np.flip(phi[:,:,-2:],axis=-1)], axis=-1)
for a,b,p in zip(i,j,prob):
y = pyrosetta.rosetta.utility.vector1_double()
_ = [y.append(v) for v in phi[a,b]]
spline = rosetta.core.scoring.func.SplineFunc("", 1.0, 0.0, ASTEP, x,y)
id1 = rosetta.core.id.AtomID(2,a+1) # CA-i
id2 = rosetta.core.id.AtomID(5,a+1) # CB-i
id3 = rosetta.core.id.AtomID(5,b+1) # CB-j
rst['phi'].append([a,b,p,rosetta.core.scoring.constraints.AngleConstraint(id1,id2,id3, spline)])
print("phi restraints: %d"%(len(rst['phi'])))
########################################################
# backbone torsions
########################################################
if (params['BB'] != ''):
bbnpz = np.load(params['BB'])
bbphi,bbpsi = bbnpz['phi'],bbnpz['psi']
rst['bbphi'] = []
rst['bbpsi'] = []
# periodic grid over -pi..pi with two padding points on each side
nbins = bbphi.shape[1]+4
step = 2.*np.pi/bbphi.shape[1]
bins = np.linspace(-1.5*step-np.pi, np.pi+1.5*step, nbins)
x = pyrosetta.rosetta.utility.vector1_double()
_ = [x.append(v) for v in bins]
# NOTE(review): no eps/MEFF is added before these logs, so a zero probability gives inf.
bbphi = -np.log(bbphi)
bbphi = np.concatenate([bbphi[:,-2:],bbphi,bbphi[:,:2]],axis=-1).copy()
bbpsi = -np.log(bbpsi)
bbpsi = np.concatenate([bbpsi[:,-2:],bbpsi,bbpsi[:,:2]],axis=-1).copy()
# i is a 1-based residue number here; each pass handles the i / i+1 peptide junction
for i in range(1,nres):
N1 = rosetta.core.id.AtomID(1,i)
Ca1 = rosetta.core.id.AtomID(2,i)
C1 = rosetta.core.id.AtomID(3,i)
N2 = rosetta.core.id.AtomID(1,i+1)
Ca2 = rosetta.core.id.AtomID(2,i+1)
C2 = rosetta.core.id.AtomID(3,i+1)
# psi(i)
ypsi = pyrosetta.rosetta.utility.vector1_double()
_ = [ypsi.append(v) for v in bbpsi[i-1]]
spsi = rosetta.core.scoring.func.SplineFunc("", BBWGHT, 0.0, step, x,ypsi)
rst['bbpsi'].append(rosetta.core.scoring.constraints.DihedralConstraint(N1,Ca1,C1,N2, spsi))
# phi(i+1)
yphi = pyrosetta.rosetta.utility.vector1_double()
_ = [yphi.append(v) for v in bbphi[i]]
sphi = rosetta.core.scoring.func.SplineFunc("", BBWGHT, 0.0, step, x,yphi)
rst['bbphi'].append(rosetta.core.scoring.constraints.DihedralConstraint(C1,N2,Ca2,C2, sphi))
print("bbbtor restraints: %d"%(len(rst['bbphi'])+len(rst['bbpsi'])))
return rst
def set_predicted_dihedral(pose, phi, psi, omega):
    """Sample backbone torsions for every residue of ``pose`` from binned predictions.

    ``phi`` and ``psi`` are indexed as ``[residue, bin]`` and used as the sampling
    probabilities over the bin centres (degrees, spanning -180..180).
    ``omega[residue, 0]`` is the probability of setting omega to 0; otherwise omega
    is set to 180. The pose is modified in place.
    """
    n_bins = phi.shape[1]
    half_width = 180./n_bins
    # left bin edges shifted by half a bin width -> bin centres
    centers = np.linspace(-180.,180.,n_bins+1)[:-1] + half_width

    for resno in range(1, pose.total_residue()+1):
        row = resno-1
        # draw order (phi, psi, omega) is kept so the random stream is consumed identically
        pose.set_phi(resno, np.random.choice(centers, p=phi[row]))
        pose.set_psi(resno, np.random.choice(centers, p=psi[row]))
        use_zero = np.random.uniform() < omega[row,0]
        pose.set_omega(resno, 0 if use_zero else 180)
def set_random_dihedral(pose):
    """Assign a randomly picked (phi, psi) pair and omega=180 to every residue.

    The pose is modified in place and also returned.
    """
    for offset in range(pose.total_residue()):
        resno = offset+1  # pose residues are numbered from 1
        torsions = random_dihedral()
        pose.set_phi(resno, torsions[0])
        pose.set_psi(resno, torsions[1])
        pose.set_omega(resno, 180)
    return pose
def random_dihedral():
    """Pick one (phi, psi) pair at random from a fixed table of backbone conformations.

    Table rows (phi, psi, omega, probability, label):
        -140  153 180 0.135 B
         -72  145 180 0.155 B
        -122  117 180 0.073 B
         -82  -14 180 0.122 A
         -61  -41 180 0.497 A
          57   39 180 0.018 L

    Returns:
        Tuple ``(phi, psi)`` in degrees.
    """
    draw = random.random()
    # (inclusive cumulative upper bound, phi, psi); the last table row is the fall-through
    cumulative = (
        (0.135, -140, 153),
        (0.29, -72, 145),
        (0.363, -122, 117),
        (0.485, -82, -14),
        (0.982, -61, -41),
    )
    for upper, phi, psi in cumulative:
        if draw <= upper:
            return (phi, psi)
    return (57, 39)
def read_fasta(file):
    """Return the sequence of the first record in a FASTA file as one string.

    Lines are right-stripped and concatenated. The first '>' header line is skipped
    and reading stops at the second one.
    """
    pieces = []
    headers_seen = 0
    with open(file, "r") as handle:
        for raw in handle:
            if raw[0] == ">":
                headers_seen += 1
                if headers_seen > 1:
                    break  # start of the next record
                continue
            pieces.append(raw.rstrip())
    return "".join(pieces)
def remove_clash(scorefxn, mover, pose):
    """Apply ``mover`` to ``pose`` up to five times, stopping once the score is below 10.

    The score is evaluated before each application; it is not re-evaluated after
    the fifth one.
    """
    attempts_left = 5
    while attempts_left > 0 and not float(scorefxn(pose)) < 10:
        mover.apply(pose)
        attempts_left -= 1
def add_rst(pose, rst, sep1, sep2, params, nogly=False, use_orient=None, pcut=None, p12_cut=0.0):
    """Select restraints from ``rst`` and add them to ``pose``.

    Args:
        pose: pose that receives the constraints (modified in place).
        rst: restraint dict; ``rst['dist']`` holds ``[a, b, p, p_12, constraint]``
            entries and ``rst['omega'|'theta'|'phi']`` hold ``[a, b, p, constraint]``.
        sep1, sep2: keep pairs with ``sep1 <= |a-b| < sep2``.
        params: dict providing 'seq' and the fallbacks 'USE_ORIENT' and 'PCUT'.
        nogly: when True, drop every pair in which either residue is 'G' in ``params['seq']``.
        use_orient: whether to add omega/theta/phi restraints; None means ``params['USE_ORIENT']``.
        pcut: minimum probability; None means ``params['PCUT']``. The orientation
            restraints use ``pcut`` plus P_ADD_OMEGA / P_ADD_THETA / P_ADD_PHI.
        p12_cut: additional minimum on ``p_12`` for distance restraints.

    Returns:
        None. Nothing is printed or applied when no restraint passes the filters.
    """
    if use_orient is None:
        use_orient = params['USE_ORIENT']
    if pcut is None:
        pcut = params['PCUT']
    seq = params['seq']
    skip_gly = (nogly == True)

    def _pair_ok(a, b):
        # sequence-separation window first, then the optional glycine exclusion
        if not (abs(a-b) >= sep1 and abs(a-b) < sep2):
            return False
        if skip_gly and (seq[a] == 'G' or seq[b] == 'G'):
            return False
        return True

    # collect restraints
    dist_r = [r for a,b,p,p_12,r in rst['dist']
              if _pair_ok(a,b) and p >= pcut and p_12 >= p12_cut]
    omega_r, theta_r, phi_r = [], [], []
    if use_orient:
        omega_r = [r for a,b,p,r in rst['omega'] if _pair_ok(a,b) and p >= pcut+P_ADD_OMEGA]
        theta_r = [r for a,b,p,r in rst['theta'] if _pair_ok(a,b) and p >= pcut+P_ADD_THETA]
        phi_r = [r for a,b,p,r in rst['phi'] if _pair_ok(a,b) and p >= pcut+P_ADD_PHI]

    # the orientation lists stay empty when use_orient is false
    array = dist_r + omega_r + theta_r + phi_r
    if len(array) < 1:
        return

    print ("Number of applied pair restraints:", len(array))
    print (" - Distance restraints:", len(dist_r))
    if use_orient:
        print (" - Omega restraints:", len(omega_r))
        print (" - Theta restraints:", len(theta_r))
        print (" - Phi restraints: ", len(phi_r))

    cset = rosetta.core.scoring.constraints.ConstraintSet()
    for constraint in array:
        cset.add_constraint(constraint)

    # add to pose
    constraints = rosetta.protocols.constraint_movers.ConstraintSetMover()
    constraints.constraint_set(cset)
    constraints.add_constraints(True)
    constraints.apply(pose)
def add_crd_rst(pose, nres, std=1.0, tol=1.0):
    """Add coordinate constraints on the "CA" atom of residues 1..nres.

    Each constraint uses the atom's current position as target and one shared
    ``FlatHarmonicFunc(0.0, std, tol)``. The pose is modified in place; nothing
    happens when ``nres`` < 1.
    """
    flat_har = rosetta.core.scoring.func.FlatHarmonicFunc(0.0, std, tol)

    restraints = []
    for resno in range(1, nres+1):
        ca_id = rosetta.core.id.AtomID(2, resno)  # CA idx for this residue
        ca_xyz = pose.residue(resno).atom("CA").xyz()  # current CA position
        # the same AtomID is passed twice (constrained atom and fixed reference atom)
        restraints.append(
            rosetta.core.scoring.constraints.CoordinateConstraint(ca_id, ca_id, ca_xyz, flat_har))

    if not restraints:
        return
    print ("Number of applied coordinate restraints:", len(restraints))

    cset = rosetta.core.scoring.constraints.ConstraintSet()
    for constraint in restraints:
        cset.add_constraint(constraint)

    # add to pose
    mover = rosetta.protocols.constraint_movers.ConstraintSetMover()
    mover.constraint_set(cset)
    mover.add_constraints(True)
    mover.apply(pose)

View file

@ -0,0 +1 @@
This work was supported by Microsoft (MB, DB, and generous gifts of Azure compute time and expertise), Eric and Wendy Schmidt by recommendation of the Schmidt Futures program (FD, HP), Open Philanthropy (DB, GRL), The Washington Research Foundation (MB, GRL, JW), National Science Foundation Cyberinfrastructure for Biological Research, Award # DBI 1937533 (IA), Wellcome Trust, grant number 209407/Z/17/Z (RJR), National Institutes of Health, grant numbers P01GM063210 (PDA, RJR), DP5OD026389 (SO), Global Challenges Research Fund (GCRF) through Science & Technology Facilities Council (STFC), grant number ST/R002754/1: Synchrotron Techniques for African Research and Technology (START) (DJO), Austrian Science Fund (FWF) projects P29432 and DOC50 (doc.fund Molecular Metabolism) (TS, CB, TP).

Binary file not shown.

After

Width:  |  Height:  |  Size: 115 KiB

View file

@ -0,0 +1,57 @@
#!/bin/bash
# Build the query MSA with iterative hhblits searches at increasingly permissive e-values.
#
# Usage: <script> <in_fasta> <out_dir> <cpu> <mem> <db_root>
#   Writes intermediate alignments to <out_dir>/hhblits and the selected alignment
#   to <out_dir>/t000_.msa0.a3m.

# inputs
in_fasta="$1"
out_dir="$2"

# resources
CPU="$3"
MEM="$4"

# sequence database
DB="$5/UniRef30_2020_06/UniRef30_2020_06"

# setup hhblits command
# (kept as an array so that paths containing spaces survive as single arguments)
HHBLITS=(hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu "$CPU" -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem "$MEM" -n 4 -d "$DB")
echo "${HHBLITS[@]}"

mkdir -p "$out_dir/hhblits"
tmp_dir="$out_dir/hhblits"
out_prefix="$out_dir/t000_"

# perform iterative searches
prev_a3m="$in_fasta"
for e in 1e-30 1e-10 1e-6 1e-3
do
    echo "$e"
    "${HHBLITS[@]}" -i "$prev_a3m" -oa3m "$tmp_dir/t000_.$e.a3m" -e "$e" -v 0
    hhfilter -id 90 -cov 75 -i "$tmp_dir/t000_.$e.a3m" -o "$tmp_dir/t000_.$e.id90cov75.a3m"
    hhfilter -id 90 -cov 50 -i "$tmp_dir/t000_.$e.a3m" -o "$tmp_dir/t000_.$e.id90cov50.a3m"
    # the next round is seeded with the 50%-coverage alignment
    prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m"
    n75=$(grep -c "^>" "$tmp_dir/t000_.$e.id90cov75.a3m")
    n50=$(grep -c "^>" "$tmp_dir/t000_.$e.id90cov50.a3m")

    # stop as soon as one of the filtered alignments is deep enough
    if ((n75>2000))
    then
        if [ ! -s "${out_prefix}.msa0.a3m" ]
        then
            cp "$tmp_dir/t000_.$e.id90cov75.a3m" "${out_prefix}.msa0.a3m"
            break
        fi
    elif ((n50>4000))
    then
        if [ ! -s "${out_prefix}.msa0.a3m" ]
        then
            cp "$tmp_dir/t000_.$e.id90cov50.a3m" "${out_prefix}.msa0.a3m"
            break
        fi
    else
        continue
    fi
done

# fall back to the most permissive search when no round reached the depth thresholds
if [ ! -s "${out_prefix}.msa0.a3m" ]
then
    cp "$tmp_dir/t000_.1e-3.id90cov50.a3m" "${out_prefix}.msa0.a3m"
fi

View file

@ -0,0 +1,29 @@
#!/bin/bash
# Predict secondary structure for an a3m alignment (csbuild -> makemat -> psipred -> psipass2)
# and write it as two FASTA-like records (>ss_pred, >ss_conf).
#
# Usage: <script> <input.a3m> <output.ss2>
#   Temporary files named <basename>.tmp.* are created in, and removed from, the
#   current working directory.

DATADIR="/workspace/psipred/data/"
echo "$DATADIR"

i_a3m="$1"
o_ss="$2"

ID="$(basename "$i_a3m" .a3m).tmp"

/workspace/csblast-2.2.3/bin/csbuild -i "$i_a3m" -I a3m -D /workspace/csblast-2.2.3/data/K4000.crf -o "$ID.chk" -O chk

# the first two lines of the a3m are used as the query FASTA
head -n 2 "$i_a3m" > "$ID.fasta"
echo "$ID.chk" > "$ID.pn"
echo "$ID.fasta" > "$ID.sn"

/workspace/blast-2.2.26/bin/makemat -P "$ID"
/workspace/psipred/bin/psipred "$ID.mtx" "$DATADIR/weights.dat" "$DATADIR/weights.dat2" "$DATADIR/weights.dat3" > "$ID.ss"
/workspace/psipred/bin/psipass2 "$DATADIR/weights_p2.dat" 1 1.0 1.0 "${i_a3m}.csb.hhblits.ss2" "$ID.ss" > "$ID.horiz"

# join the per-line Pred/Conf strings of the .horiz file into one line per record
(
echo ">ss_pred"
grep "^Pred" "$ID.horiz" | awk '{print $2}'
echo ">ss_conf"
grep "^Conf" "$ID.horiz" | awk '{print $2}'
) | awk '{if(substr($1,1,1)==">") {print "\n"$1} else {printf "%s", $1}} END {print ""}' | sed "1d" > "$o_ss"

# clean up intermediates (the glob stays outside the quotes so it still expands)
rm "${i_a3m}.csb.hhblits.ss2"
rm "$ID".*

View file

@ -0,0 +1,18 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Template search: prepend the predicted secondary structure to the MSA and run hhsearch.
# Arguments: <work_dir> <cpu> <mem> <db_root>
WDIR="$1"
CPU="$2"
MEM="$3"
DB="$4/pdb100_2021Mar03/pdb100_2021Mar03"

# Kept as a plain string that is word-split on use, so the script does not depend on
# bash arrays; consequently <db_root> must not contain whitespace.
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB"

cat "$WDIR/t000_.ss2" "$WDIR/t000_.msa0.a3m" > "$WDIR/t000_.msa0.ss2.a3m"
$HH -i "$WDIR/t000_.msa0.ss2.a3m" -o "$WDIR/t000_.hhr" -atab "$WDIR/t000_.atab" -v 0

View file

@ -0,0 +1,27 @@
#!/bin/bash
# install external program not supported by conda installation
# Downloads lddt and cs-blast into the current working directory.
# map the kernel name reported by uname to the platform tag used in the download URLs
case "$(uname -s)" in
Linux*) platform=linux;;
Darwin*) platform=macosx;;
*) echo "unsupported OS type. exiting"; exit 1
esac
echo "installing for ${platform}"
# download lddt
echo "downloading lddt . . ."
wget https://openstructure.org/static/lddt-${platform}.zip -O lddt.zip
# -j drops the archive's directory structure, so all files land directly in ./lddt
unzip -d lddt -j lddt.zip
# the cs-blast platform description includes the width of memory addresses
# we expect a 64-bit operating system
if [[ ${platform} == "linux" ]]; then
platform=${platform}64
fi
# download cs-blast
echo "downloading cs-blast . . ."
wget http://wwwuser.gwdg.de/~compbiol/data/csblast/releases/csblast-2.2.3_${platform}.tar.gz -O csblast-2.2.3.tar.gz
mkdir -p csblast-2.2.3
# --strip-components=1 removes the archive's top-level directory while extracting
tar xf csblast-2.2.3.tar.gz -C csblast-2.2.3 --strip-components=1

View file

@ -0,0 +1,480 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from Transformer import *
from Transformer import _get_clones
import Transformer
from resnet import ResidualNetwork
from SE3_network import SE3Transformer
from InitStrGenerator import InitStr_Network
import dgl
# Attention module based on AlphaFold2's idea written by Minkyung Baek
# - Iterative MSA feature extraction
# - 1) MSA2Pair: extract pairwise feature from MSA --> added to previous residue-pair features
# architecture design inspired by CopulaNet paper
# - 2) MSA2MSA: process MSA features using Transformer (or Performer) encoder. (Attention over L first followed by attention over N)
# - 3) Pair2MSA: Update MSA features using pair feature
# - 4) Pair2Pair: process pair features using Transformer (or Performer) encoder.
def make_graph(xyz, pair, idx, top_k=64, kmin=9):
    '''
    Build the residue graph used by the SE(3) network.

    Input:
        - xyz: current backbone coordinates (B, L, 3, 3); atom index 1 is used as CA
        - pair: pair features from Trunk (B, L, L, E)
        - idx: residue index from ground truth pdb
        - top_k: number of nearest CA neighbours connected to each residue
        - kmin: residues with sequence separation below kmin are always connected
    Output:
        - G: defined graph with edge data 'd' (CA displacement j - i, detached)
             and 'w' (pair features of the edge)
    '''
    n_batch, n_res = xyz.shape[:2]
    device = xyz.device
    ca = xyz[:,:,1,:]

    # large value on the diagonal so a residue is never its own neighbour
    self_mask = torch.eye(n_res, device=device).unsqueeze(0)*999.9

    # distance map from current CA coordinates, (B, L, L)
    ca_dist = torch.cdist(ca, ca) + self_mask

    # sequence separation
    seq_sep = (idx[:,None,:] - idx[:,:,None]).abs() + self_mask

    # mark the top_k nearest neighbours of every residue; nbr_idx is (B, L, top_k)
    _, nbr_idx = torch.topk(ca_dist, min(top_k, n_res), largest=False)
    is_topk = torch.zeros((n_batch, n_res, n_res), device=device)
    is_topk.scatter_(2, nbr_idx, 1.0)

    # put an edge if either condition is met:
    #   1) sequence separation < kmin (connect sequentially close residues)
    #   2) j is among the top_k spatial neighbours of i
    has_edge = torch.logical_or(is_topk > 0.0, seq_sep < kmin)
    b, i, j = torch.where(has_edge)

    # nodes of all batch members share one graph; node id = batch*L + residue
    G = dgl.graph((b*n_res+i, b*n_res+j), num_nodes=n_batch*n_res).to(device)
    G.edata['d'] = (xyz[b,j,1,:] - xyz[b,i,1,:]).detach() # no gradient through basis function
    G.edata['w'] = pair[b,i,j]
    return G
def get_bonded_neigh(idx):
    '''
    Signed adjacency of sequence neighbours.

    Input:
        - idx: residue indices of given sequence (B,L)
    Output:
        - neighbor: bonded neighbor information with sign (B, L, L, 1);
                    entry [b,i,j] is idx[b,j]-idx[b,i] when its magnitude is
                    at most 1, and zero otherwise
    '''
    offset = (idx[:,None,:] - idx[:,:,None]).float()
    direction = torch.sign(offset) # (B, L, L)
    magnitude = torch.abs(offset)
    # anything further apart than one position is not a bonded neighbour
    magnitude = magnitude.masked_fill(magnitude > 1, 0.0)
    return (direction * magnitude).unsqueeze(-1)
def rbf(D):
    """Expand distances with Gaussian radial basis functions.

    Uses 36 centres evenly spaced over [0, 20] and a width of 20/36. The output
    has the shape of ``D`` with one trailing axis of size 36 appended.
    """
    d_min, d_max, n_centers = 0., 20., 36
    centers = torch.linspace(d_min, d_max, n_centers).to(D.device).unsqueeze(0)
    width = (d_max - d_min) / n_centers
    scaled = (D.unsqueeze(-1) - centers) / width
    return torch.exp(-(scaled**2))
class CoevolExtractor(nn.Module):
    """Turn two projected MSA tensors into pair features via an outer product over channels."""

    def __init__(self, n_feat_proj, n_feat_out, p_drop=0.1):
        # p_drop is accepted for interface compatibility; it is not used in this module
        super().__init__()
        self.norm_2d = LayerNorm(n_feat_proj*n_feat_proj)
        # project down to output dimension (pair feature dimension)
        self.proj_2 = nn.Linear(n_feat_proj**2, n_feat_out)

    def forward(self, x_down, x_down_w):
        n_batch, _, n_res = x_down.shape[:3]
        # outer product of the channel axes, summed over the sequence axis (weighted pooling)
        outer = torch.einsum('abij,ablm->ailjm', x_down, x_down_w)
        flat = outer.reshape(n_batch, n_res, n_res, -1)
        # normalise, then project down to pair dimension: (B, L, L, n_feat_out)
        return self.proj_2(self.norm_2d(flat))
class MSA2Pair(nn.Module):
    """Update the pair features from the MSA embedding.

    Combines the previous pair features, co-evolution features from CoevolExtractor,
    tiled 1D features and the attention maps passed in as ``att``.
    """

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32,
                 n_resblock=1, p_drop=0.1, n_att_head=8):
        super().__init__()
        # project down embedding dimension (n_feat --> n_feat_proj)
        self.norm_1 = LayerNorm(n_feat)
        self.proj_1 = nn.Linear(n_feat, n_feat_proj)

        self.encoder = SequenceWeight(n_feat_proj, 1, dropout=p_drop)
        self.coevol = CoevolExtractor(n_feat_proj, n_feat_out)

        # ResNet to update pair features
        self.norm_down = LayerNorm(n_feat_proj)
        self.norm_orig = LayerNorm(n_feat_out)
        self.norm_new = LayerNorm(n_feat_out)
        # input channels: old pair + new pair + left/right 1D tiles (2*n_feat_proj each) + attention heads
        n_channels_in = n_feat_out*2+n_feat_proj*4+n_att_head
        self.update = ResidualNetwork(n_resblock, n_channels_in, n_feat_out, n_feat_out, p_drop=p_drop)

    def forward(self, msa, pair_orig, att):
        # Input: MSA embeddings (B, N, L, K), original pair embeddings (B, L, L, C)
        # Output: updated pair info (B, L, L, C)
        n_batch, n_seq, n_res, _ = msa.shape

        # project down to reduce memory, (B, N, L, n_feat_proj)
        projected = self.norm_down(self.proj_1(self.norm_1(msa)))

        # per-sequence weights, rearranged to broadcast against the projected MSA
        weights = self.encoder(projected).reshape(n_batch, n_res, 1, n_seq).permute(0,3,1,2)
        weighted = weights*projected
        coevol = self.coevol(projected, weighted)

        # weighted pooling over N, concatenated with the query (first) sequence features
        pooled = weighted.sum(1)
        query = projected[:,0] # (B,L,K)
        feat_1d = torch.cat((pooled, query), dim=-1)

        # tile 1D features along both residue axes
        left = feat_1d.unsqueeze(2).repeat(1, 1, n_res, 1)
        right = feat_1d.unsqueeze(1).repeat(1, n_res, 1, 1)

        # update original pair features through convolutions after concat
        pair_orig = self.norm_orig(pair_orig)
        coevol = self.norm_new(coevol)
        stacked = torch.cat((pair_orig, coevol, left, right, att), -1)
        stacked = stacked.permute(0,3,1,2).contiguous() # channels first for the convolution layers
        updated = self.update(stacked)
        return updated.permute(0,2,3,1).contiguous() # (B, L, L, C)
class MSA2MSA(nn.Module):
    """Process MSA features with attention along L followed by attention along N."""

    def __init__(self, n_layer=1, n_att_head=8, n_feat=256, r_ff=4, p_drop=0.1,
                 performer_N_opts=None, performer_L_opts=None):
        super().__init__()
        d_ff = n_feat*r_ff

        # attention along L (tied); performer_L_opts is accepted but not forwarded here
        layer_over_L = EncoderLayer(d_model=n_feat, d_ff=d_ff,
                                    heads=n_att_head, p_drop=p_drop,
                                    use_tied=True)
        self.encoder_1 = Encoder(layer_over_L, n_layer)

        # attention along N
        layer_over_N = EncoderLayer(d_model=n_feat, d_ff=d_ff,
                                    heads=n_att_head, p_drop=p_drop,
                                    performer_opts=performer_N_opts)
        self.encoder_2 = Encoder(layer_over_N, n_layer)

    def forward(self, x):
        # Input: MSA embeddings (B, N, L, K)
        # Output: updated MSA embeddings (B, N, L, K) and the attention returned by the first encoder
        B, N, L, _ = x.shape # unpacking requires a 4-D input

        # attention along L
        x, att = self.encoder_1(x, return_att=True)

        # attention along N: swap the N and L axes, encode, swap back
        swapped = self.encoder_2(x.permute(0,2,1,3).contiguous())
        return swapped.permute(0,2,1,3).contiguous(), att
class Pair2MSA(nn.Module):
    """Update MSA features using the pair features through a CrossEncoder."""

    def __init__(self, n_layer=1, n_att_head=4, n_feat_in=128, n_feat_out=256, r_ff=4, p_drop=0.1):
        super().__init__()
        layer = DirectEncoderLayer(
            heads=n_att_head,
            d_in=n_feat_in,
            d_out=n_feat_out,
            d_ff=n_feat_out*r_ff,
            p_drop=p_drop,
        )
        self.encoder = CrossEncoder(layer, n_layer)

    def forward(self, pair, msa):
        # output has the MSA layout (B, N, L, K)
        return self.encoder(pair, msa)
class Pair2Pair(nn.Module):
    """Process pair features with a stack of axial encoder layers."""

    def __init__(self, n_layer=1, n_att_head=8, n_feat=128, r_ff=4, p_drop=0.1,
                 performer_L_opts=None):
        super().__init__()
        layer = AxialEncoderLayer(
            d_model=n_feat,
            d_ff=n_feat*r_ff,
            heads=n_att_head,
            p_drop=p_drop,
            performer_opts=performer_L_opts,
        )
        self.encoder = Encoder(layer, n_layer)

    def forward(self, x):
        updated = self.encoder(x)
        return updated
class Str2Str(nn.Module):
    """Refine backbone coordinates with an SE(3)-Transformer over the residue graph."""

    def __init__(self, d_msa=64, d_pair=128,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
        super().__init__()
        n_node_feat = SE3_param['l0_in_features']
        n_edge_feat = SE3_param['num_edge_features']

        # initial node & pair feature process
        self.norm_msa = LayerNorm(d_msa)
        self.norm_pair = LayerNorm(d_pair)
        self.encoder_seq = SequenceWeight(d_msa, 1, dropout=p_drop)

        # node input is the pooled MSA feature concatenated with seq1hot (21 extra channels)
        self.embed_x = nn.Linear(d_msa+21, n_node_feat)
        self.embed_e = nn.Linear(d_pair, n_edge_feat)

        self.norm_node = LayerNorm(n_node_feat)
        self.norm_edge = LayerNorm(n_edge_feat)

        self.se3 = SE3Transformer(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
        # process msa & pair features
        n_batch, n_seq, n_res = msa.shape[:3]
        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        # weighted pooling of the MSA over N, then append the one-hot sequence
        seq_weight = self.encoder_seq(msa).reshape(n_batch, n_res, 1, n_seq).permute(0,3,1,2)
        node = (seq_weight*msa).sum(dim=1)
        node = torch.cat((node, seq1hot), dim=-1)
        node = self.norm_node(self.embed_x(node))
        edge = self.norm_edge(self.embed_e(pair))

        # define graph
        G = make_graph(xyz, edge, idx, top_k=top_k)

        # l1 features = displacement vector of every backbone atom to CA
        to_ca = (xyz - xyz[:,:,1,:].unsqueeze(2)).reshape(n_batch*n_res, -1, 3)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(n_batch*n_res, -1, 1), to_ca)
        state = shift['0'].reshape(n_batch, n_res, -1) # (B, L, C)
        offset = shift['1'].reshape(n_batch, n_res, -1, 3) # (B, L, 3, 3)

        # CA is moved by its own offset; N and C are placed relative to the new CA
        ca_new = xyz[:,:,1] + offset[:,:,1]
        n_new = ca_new + offset[:,:,0]
        c_new = ca_new + offset[:,:,2]
        return torch.stack([n_new, ca_new, c_new], dim=2), state
class Str2MSA(nn.Module):
    """Update MSA features with attention masked by CA-CA distance.

    One attention head is created per entry of ``distbin``; each head gets a soft
    mask ``1 - sigmoid(dist - cutoff)``.
    """

    def __init__(self, d_msa=64, d_state=32, inner_dim=32, r_ff=4,
                 distbin=[8.0, 12.0, 16.0, 20.0], p_drop=0.1):
        super().__init__()
        self.distbin = distbin
        n_heads = len(distbin)

        self.norm_state = LayerNorm(d_state)
        self.norm1 = LayerNorm(d_msa)
        self.attn = MaskedDirectMultiheadAttention(d_state, d_msa, n_heads, d_k=inner_dim, dropout=p_drop)
        self.dropout1 = nn.Dropout(p_drop,inplace=True)

        self.norm2 = LayerNorm(d_msa)
        self.ff = FeedForwardLayer(d_msa, d_msa*r_ff, p_drop=p_drop)
        self.dropout2 = nn.Dropout(p_drop,inplace=True)

    def forward(self, msa, xyz, state):
        # CA-CA distances (atom index 1), (B, L, L)
        ca = xyz[:,:,1]
        dist = torch.cdist(ca, ca)

        # one soft distance mask per head, stacked to (B, h, L, L)
        mask_s = torch.stack(
            [1.0 - torch.sigmoid(dist-cutoff) for cutoff in self.distbin], dim=1)

        # attention sub-layer with residual connection
        state = self.norm_state(state)
        attended = self.attn(state, state, self.norm1(msa), mask_s)
        msa = msa + self.dropout1(attended)

        # feed-forward sub-layer with residual connection
        transformed = self.ff(self.norm2(msa))
        msa = msa + self.dropout2(transformed)
        return msa
class IterBlock(nn.Module):
    """One iteration of the MSA/pair update cycle (no structure track)."""

    def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
                 n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None):
        super().__init__()

        self.msa2msa = MSA2MSA(
            n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
            r_ff=r_ff, p_drop=p_drop,
            performer_N_opts=performer_N_opts,
            performer_L_opts=performer_L_opts,
        )
        self.msa2pair = MSA2Pair(
            n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
            n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa,
        )
        self.pair2pair = Pair2Pair(
            n_layer=n_layer, n_att_head=n_head_pair,
            n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
            performer_L_opts=performer_L_opts,
        )
        self.pair2msa = Pair2MSA(
            n_layer=n_layer, n_att_head=4,
            n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop,
        )

    def forward(self, msa, pair):
        # input:
        #   msa: MSA embeddings; MSA2MSA unpacks them as (B, N, L, d_msa)
        #   pair: residue pair embeddings; MSA2Pair documents them as (B, L, L, d_pair)

        # 1. process MSA features
        msa, att = self.msa2msa(msa)

        # 2. update pair features using given MSA, then 3. process pair features
        pair = self.pair2pair(self.msa2pair(msa, pair, att))

        # 4. update MSA features using updated pair features
        msa = self.pair2msa(pair, msa)

        return msa, pair
class IterBlock_w_Str(nn.Module):
    """One MSA/pair update cycle followed by a structure update that feeds back into the MSA."""

    def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
                 n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        super().__init__()

        self.msa2msa = MSA2MSA(
            n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
            r_ff=r_ff, p_drop=p_drop,
            performer_N_opts=performer_N_opts,
            performer_L_opts=performer_L_opts,
        )
        self.msa2pair = MSA2Pair(
            n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
            n_resblock=n_resblock, p_drop=p_drop, n_att_head=n_head_msa,
        )
        self.pair2pair = Pair2Pair(
            n_layer=n_layer, n_att_head=n_head_pair,
            n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
            performer_L_opts=performer_L_opts,
        )
        self.pair2msa = Pair2MSA(
            n_layer=n_layer, n_att_head=4,
            n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop,
        )
        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
        # the state returned by Str2Str has l0_out_features channels
        self.str2msa = Str2MSA(
            d_msa=d_msa, d_state=SE3_param['l0_out_features'],
            r_ff=r_ff, p_drop=p_drop,
        )

    def forward(self, msa, pair, xyz, seq1hot, idx, top_k=64):
        # input:
        #   msa: MSA embeddings; MSA2MSA unpacks them as (B, N, L, d_msa)
        #   pair: residue pair embeddings; MSA2Pair documents them as (B, L, L, d_pair)

        # 1. process MSA features
        msa, att = self.msa2msa(msa)

        # 2. update pair features using given MSA, then 3. process pair features
        pair = self.pair2pair(self.msa2pair(msa, pair, att))

        # 4. update MSA features using updated pair features
        msa = self.pair2msa(pair, msa)

        # 5. update the structure (inputs cast to float32), then feed it back into the MSA
        xyz, state = self.str2str(msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=top_k)
        msa = self.str2msa(msa, xyz, state)

        return msa, pair, xyz
class FinalBlock(nn.Module):
    """Last trunk block: one more feature/structure update plus an lDDT head.

    It performs the same MSA/pair updates as the iteration blocks, calls
    ``Str2Str`` once with a fixed ``top_k=32`` and predicts one lDDT value per
    position from the normalised state returned by ``Str2Str``.
    """

    def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
                 n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        super().__init__()
        d_state = SE3_param['l0_out_features']
        # Construction order is preserved so seeded initialisation is unchanged.
        self.msa2msa = MSA2MSA(
            n_layer=n_layer,
            n_att_head=n_head_msa,
            n_feat=d_msa,
            r_ff=r_ff,
            p_drop=p_drop,
            performer_N_opts=performer_N_opts,
            performer_L_opts=performer_L_opts,
        )
        self.msa2pair = MSA2Pair(
            n_feat=d_msa,
            n_feat_out=d_pair,
            n_feat_proj=32,
            n_resblock=n_resblock,
            p_drop=p_drop,
            n_att_head=n_head_msa,
        )
        self.pair2pair = Pair2Pair(
            n_layer=n_layer,
            n_att_head=n_head_pair,
            n_feat=d_pair,
            r_ff=r_ff,
            p_drop=p_drop,
            performer_L_opts=performer_L_opts,
        )
        # NOTE: the head count of pair2msa is fixed to 4, independent of n_head_msa.
        self.pair2msa = Pair2MSA(
            n_layer=n_layer,
            n_att_head=4,
            n_feat_in=d_pair,
            n_feat_out=d_msa,
            r_ff=r_ff,
            p_drop=p_drop,
        )
        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, SE3_param=SE3_param, p_drop=p_drop)
        self.norm_state = LayerNorm(d_state)
        self.pred_lddt = nn.Linear(d_state, 1)

    def forward(self, msa, pair, xyz, seq1hot, idx):
        """Return ``(msa, pair, xyz, lddt)``; ``lddt`` has its trailing unit axis removed.

        Per the original annotations ``msa`` holds MSA embeddings
        (N, L, d_msa) and ``pair`` holds residue-pair embeddings
        (L, L, d_pair).
        """
        # MSA self-update, then MSA -> pair -> pair refinement
        msa, attn_map = self.msa2msa(msa)
        pair = self.pair2pair(self.msa2pair(msa, pair, attn_map))
        # pair -> MSA
        msa = self.pair2msa(pair, msa)
        # final structure update uses a fixed neighbourhood size of 32
        xyz, state = self.str2str(
            msa.float(), pair.float(), xyz.float(), seq1hot, idx, top_k=32
        )
        lddt = self.pred_lddt(self.norm_state(state)).squeeze(-1)
        return msa, pair, xyz, lddt
class IterativeFeatureExtractor(nn.Module):
    """Full trunk: pair pre-processing, 2-track blocks, initial structure,
    3-track (structure-aware) blocks and a final block with an lDDT head.

    Args:
        n_module: number of structure-free ``IterBlock`` copies (0 disables them).
        n_module_str: number of ``IterBlock_w_Str`` copies (0 disables them).
        n_layer, d_msa, d_pair, n_head_msa, n_head_pair, r_ff, n_resblock,
        p_drop, performer_L_opts, performer_N_opts: forwarded to the blocks.
        d_hidden: hidden width of the initial-structure network.
        SE3_param: keyword dictionary forwarded to the structure-aware blocks.
    """

    # Neighbourhood size (top_k) used by successive structure-aware blocks.
    # Blocks beyond the end of this schedule reuse its last entry.
    _TOP_K_SCHEDULE = (128, 128, 64, 64)

    def __init__(self, n_module=4, n_module_str=4, n_layer=4, d_msa=256, d_pair=128, d_hidden=64,
                 n_head_msa=8, n_head_pair=8, r_ff=4,
                 n_resblock=1, p_drop=0.1,
                 performer_L_opts=None, performer_N_opts=None,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        super(IterativeFeatureExtractor, self).__init__()
        self.n_module = n_module
        self.n_module_str = n_module_str
        #
        self.initial = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
                                 n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
                                 performer_L_opts=performer_L_opts)

        if self.n_module > 0:
            # one template block is built and then deep-copied n_module times
            self.iter_block_1 = _get_clones(IterBlock(n_layer=n_layer,
                                                      d_msa=d_msa, d_pair=d_pair,
                                                      n_head_msa=n_head_msa,
                                                      n_head_pair=n_head_pair,
                                                      r_ff=r_ff,
                                                      n_resblock=n_resblock,
                                                      p_drop=p_drop,
                                                      performer_N_opts=performer_N_opts,
                                                      performer_L_opts=performer_L_opts
                                                      ), n_module)

        self.init_str = InitStr_Network(node_dim_in=d_msa, node_dim_hidden=d_hidden,
                                        edge_dim_in=d_pair, edge_dim_hidden=d_hidden,
                                        nheads=4, nblocks=3, dropout=p_drop)

        if self.n_module_str > 0:
            self.iter_block_2 = _get_clones(IterBlock_w_Str(n_layer=n_layer,
                                                            d_msa=d_msa, d_pair=d_pair,
                                                            n_head_msa=n_head_msa,
                                                            n_head_pair=n_head_pair,
                                                            r_ff=r_ff,
                                                            n_resblock=n_resblock,
                                                            p_drop=p_drop,
                                                            performer_N_opts=performer_N_opts,
                                                            performer_L_opts=performer_L_opts,
                                                            SE3_param=SE3_param
                                                            ), n_module_str)

        self.final = FinalBlock(n_layer=n_layer, d_msa=d_msa, d_pair=d_pair,
                                n_head_msa=n_head_msa, n_head_pair=n_head_pair, r_ff=r_ff,
                                n_resblock=n_resblock, p_drop=p_drop,
                                performer_L_opts=performer_L_opts, performer_N_opts=performer_N_opts,
                                SE3_param=SE3_param)

    def forward(self, msa, pair, seq1hot, idx):
        """Run the whole trunk.

        Per the original annotations ``msa`` holds MSA embeddings
        (N, L, d_msa) and ``pair`` holds residue-pair embeddings
        (L, L, d_pair).

        Returns:
            ``(msa[:, 0], pair, xyz, lddt)`` - only index 0 along the second
            axis of the final MSA tensor is returned.
        """
        pair = self.initial(pair)

        # structure-free iterations: refine MSA and pair features
        for i_m in range(self.n_module):
            msa, pair = self.iter_block_1[i_m](msa, pair)

        # initial coordinates from the current features
        xyz = self.init_str(seq1hot, idx, msa, pair)

        # structure-aware iterations; clamp the schedule index so that
        # n_module_str > len(schedule) no longer raises IndexError
        schedule = self._TOP_K_SCHEDULE
        last = len(schedule) - 1
        for i_m in range(self.n_module_str):
            top_k = schedule[min(i_m, last)]
            msa, pair, xyz = self.iter_block_2[i_m](msa, pair, xyz, seq1hot, idx, top_k=top_k)

        msa, pair, xyz, lddt = self.final(msa, pair, xyz, seq1hot, idx)
        return msa[:, 0], pair, xyz, lddt

View file

@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from resnet import ResidualNetwork
from Transformer import LayerNorm
# predict distance map from pair features
# based on simple 2D ResNet
class DistanceNetwork(nn.Module):
    """Turns pair features into four sets of logits using 2D residual networks.

    The distance and omega heads read a feature map that has been averaged
    with its transpose; the theta and phi heads read the untouched map.
    The dist/omega/theta heads emit 37 channels, the phi head 19.
    """

    def __init__(self, n_feat, n_block=1, block_type='orig', p_drop=0.0):
        super().__init__()
        self.norm = LayerNorm(n_feat)
        self.proj = nn.Linear(n_feat, n_feat)
        self.drop = nn.Dropout(p_drop)

        def make_head(n_out):
            # all heads share the trunk settings; only the output width differs
            return ResidualNetwork(n_block, n_feat, n_feat, n_out,
                                   block_type=block_type, p_drop=p_drop)

        # creation order (dist, omega, theta, phi) is kept for reproducible init
        self.resnet_dist = make_head(37)
        self.resnet_omega = make_head(37)
        self.resnet_theta = make_head(37)
        self.resnet_phi = make_head(19)

    def forward(self, x):
        """Return ``(logits_dist, logits_omega, logits_theta, logits_phi)``.

        ``x`` is the pair tensor, (B, L, L, C) per the original annotation.
        """
        feat = self.drop(self.proj(self.norm(x)))
        # move the channel axis to position 1 for the 2D networks
        feat = feat.permute(0, 3, 1, 2).contiguous()

        # asymmetric heads first (same call order as before)
        logits_theta = self.resnet_theta(feat)
        logits_phi = self.resnet_phi(feat)

        # symmetric heads read the average of the map and its transpose
        sym = 0.5 * (feat + feat.permute(0, 1, 3, 2))
        logits_dist = self.resnet_dist(sym)
        logits_omega = self.resnet_omega(sym)

        return logits_dist, logits_omega, logits_theta, logits_phi

View file

@ -0,0 +1,175 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from Transformer import EncoderLayer, AxialEncoderLayer, Encoder, LayerNorm
# Initial embeddings for target sequence, msa, template info
# positional encoding
# option 1: using sin/cos --> using this for now
# option 2: learn positional embedding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is looked up by residue index.

    A ``(1, max_len, d_model)`` table is stored as the buffer ``pe``: even
    channels hold sines, odd channels hold cosines.

    Args:
        d_model: channel width of the tensor being encoded. Odd widths are
            supported (the last cosine column is dropped).
        p_drop: dropout probability applied after adding the encoding.
        max_len: number of positions in the table; indices passed to
            ``forward`` must be smaller than this.
    """

    def __init__(self, d_model, p_drop=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.drop = nn.Dropout(p_drop, inplace=True)

        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine block; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x, idx_s):
        """Add the encoding selected by ``idx_s`` to ``x`` and apply dropout.

        Args:
            x: tensor to encode; the gathered encoding has shape
                ``(len(idx_s), 1, L, d_model)`` and is broadcast against it.
            idx_s: iterable of integer index tensors, one per batch entry.
        """
        # Each lookup yields (1, L, d_model); stacking adds the batch axis.
        # The buffer never requires grad, so no autograd wrapper is needed.
        pe = torch.stack([self.pe[:, idx, :] for idx in idx_s])
        # `x + pe` creates a new tensor, so the in-place dropout does not
        # overwrite the caller's input.
        x = x + pe
        return self.drop(x)
class PositionalEncoding2D(nn.Module):
    """Sinusoidal positional encoding for (B, L, L, K) pair tensors.

    The first half of the channels encodes the row index, the second half the
    column index. The shapes only line up when K is a multiple of 4, because
    each half is built from a sine block and a cosine block of equal width.
    """

    def __init__(self, d_model, p_drop=0.1):
        super().__init__()
        self.drop = nn.Dropout(p_drop, inplace=True)
        half_width = d_model // 2
        freq = torch.exp(torch.arange(0., half_width, 2) *
                         -(math.log(10000.0) / half_width))
        self.register_buffer('div_term', freq)

    def forward(self, x, idx_s):
        """Add the encoding built from ``idx_s`` to ``x`` and apply dropout.

        Only the first ``len(idx_s)`` entries along axis 0 of ``x`` receive an
        encoding; any further entries are left with a zero offset.
        """
        _, _, _, n_chan = x.shape
        half = n_chan // 2
        pe = torch.zeros_like(x)
        for b, idx in enumerate(idx_s):
            angle = idx.unsqueeze(1) * self.div_term
            emb = torch.cat((angle.sin(), angle.cos()), dim=-1)  # (L, K//2)
            pe[b, :, :, :half] = emb.unsqueeze(1)  # varies with the row index
            pe[b, :, :, half:] = emb.unsqueeze(0)  # varies with the column index
        # `x + pe` is a fresh tensor, so in-place dropout leaves the input intact
        return self.drop(x + pe)
class QueryEncoding(nn.Module):
    """Learned two-entry embedding that marks the query row of an MSA tensor.

    Embedding row 0 is added to sequence 0 (the query); embedding row 1 is
    added to every other sequence.
    """

    def __init__(self, d_model):
        super().__init__()
        self.pe = nn.Embedding(2, d_model)  # row 0: query, row 1: all others

    def forward(self, x):
        """Return ``x`` plus the query / non-query embedding; ``x`` is (B, N, L, K)."""
        n_batch, n_seq, n_res, _ = x.shape
        which = torch.ones((n_batch, n_seq, n_res), device=x.device).long()
        which[:, 0, :] = 0  # first sequence is the query
        return x + self.pe(which)
class MSA_emb(nn.Module):
    """Initial MSA embedding: token embedding + positional + query encoding."""

    def __init__(self, d_model=64, d_msa=21, p_drop=0.1, max_len=5000):
        super().__init__()
        self.emb = nn.Embedding(d_msa, d_model)
        self.pos = PositionalEncoding(d_model, p_drop=p_drop, max_len=max_len)
        self.pos_q = QueryEncoding(d_model)

    def forward(self, msa, idx):
        """Embed integer MSA tokens.

        ``msa`` must be 3-dimensional (B, N, L); ``idx`` is passed to the
        positional encoding.
        """
        B, N, L = msa.shape  # also rejects inputs that are not 3-dimensional
        tokens = self.emb(msa)             # (B, N, L, d_model)
        encoded = self.pos(tokens, idx)    # add positional encoding
        return self.pos_q(encoded)         # add query encoding
# pixel-wise attention based embedding (from trRosetta-tbm)
class Templ_emb(nn.Module):
    """Embeds template features and pools them over templates with attention.

    Each template is projected to ``d_templ`` channels, positionally encoded,
    passed through one axial encoder layer and finally combined across the
    template axis by a learned per-position softmax weighting.
    """

    def __init__(self, d_t1d=3, d_t2d=10, d_templ=64, n_att_head=4, r_ff=4,
                 performer_opts=None, p_drop=0.1, max_len=5000):
        super().__init__()
        # input = t2d + t1d of residue i + t1d of residue j + sequence separation
        self.proj = nn.Linear(d_t1d*2+d_t2d+1, d_templ)
        self.pos = PositionalEncoding2D(d_templ, p_drop=p_drop)
        # attention along L
        axial_layer = AxialEncoderLayer(d_templ, d_templ*r_ff, n_att_head, p_drop=p_drop,
                                        performer_opts=performer_opts)
        self.encoder_L = Encoder(axial_layer, 1)
        self.norm = LayerNorm(d_templ)
        self.to_attn = nn.Linear(d_templ, 1)

    def forward(self, t1d, t2d, idx):
        """Return pooled template features reshaped to (B, L, L, -1).

        Args:
            t1d: per-residue template features (B, T, L, C1). ``self.proj``
                expects C1 == d_t1d (the original note says 2 while the
                default d_t1d is 3 - TODO confirm against the caller).
            t2d: pairwise template features (B, T, L, L, C2), C2 == d_t2d.
            idx: residue indices (B, L).
        """
        B, T, L, _ = t1d.shape

        # broadcast 1D features to both axes of the pair map
        row_feat = t1d.unsqueeze(3).expand(-1, -1, -1, L, -1)
        col_feat = t1d.unsqueeze(2).expand(-1, -1, L, -1, -1)

        # log(|i - j| + 1), replicated for every template
        sep = torch.abs(idx[:, :, None] - idx[:, None, :]) + 1
        sep = torch.log(sep.float()).view(B, L, L, 1)
        sep = sep.unsqueeze(1).expand(-1, T, -1, -1, -1)

        stacked = torch.cat((t2d, row_feat, col_feat, sep), -1)
        stacked = self.proj(stacked).reshape(B*T, L, L, -1)
        # NOTE(review): `stacked` has B*T rows while `idx` has B rows - confirm
        # how self.pos pairs them when T > 1.
        encoded = self.pos(stacked, idx)  # add positional embedding

        # attention along L, one template at a time
        feat = torch.empty_like(encoded)
        for i_t in range(encoded.shape[0]):
            feat[i_t] = self.encoder_L(encoded[i_t].view(1, L, L, -1))
        del encoded

        # bring the template axis next to the channels: (B, L*L, T, C)
        feat = feat.reshape(B, T, L, L, -1)
        feat = feat.permute(0, 2, 3, 1, 4).contiguous().reshape(B, L*L, T, -1)

        # softmax over templates, then weighted sum
        weights = F.softmax(self.to_attn(self.norm(feat)), dim=-2)  # (B, L*L, T, 1)
        pooled = torch.matmul(weights.transpose(-2, -1), feat)
        return pooled.reshape(B, L, L, -1)
class Pair_emb_w_templ(nn.Module):
    """Initial pair embedding built from the sequence plus template features."""

    def __init__(self, d_model=128, d_seq=21, d_templ=64, p_drop=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_emb = d_model // 2
        self.emb = nn.Embedding(d_seq, self.d_emb)
        self.norm_templ = LayerNorm(d_templ)
        # input = residue-i emb + residue-j emb + template features + separation
        self.projection = nn.Linear(d_model + d_templ + 1, d_model)
        self.pos = PositionalEncoding2D(d_model, p_drop=p_drop)

    def forward(self, seq, idx, templ):
        """Return positionally encoded pair features.

        Args:
            seq: integer residue tokens (B, L); they are fed to ``nn.Embedding``.
            idx: residue indices (B, L).
            templ: template features whose last axis matches ``d_templ``
                (presumably (B, L, L, d_templ) - confirm against caller).
        """
        n_batch, n_res = seq.shape[0], seq.shape[1]

        tok = self.emb(seq)  # (B, L, d_model//2)
        row_feat = tok.unsqueeze(2).expand(-1, -1, n_res, -1)
        col_feat = tok.unsqueeze(1).expand(-1, n_res, -1, -1)

        # log(|i - j| + 1)
        sep = torch.abs(idx[:, :, None] - idx[:, None, :]) + 1
        sep = torch.log(sep.float()).view(n_batch, n_res, n_res, 1)

        templ = self.norm_templ(templ)
        pair = self.projection(torch.cat((row_feat, col_feat, sep, templ), dim=-1))
        return self.pos(pair, idx)
class Pair_emb_wo_templ(nn.Module):
    """Initial pair embedding built from the sequence only (no templates)."""

    # TODO: embedding without template info
    def __init__(self, d_model=128, d_seq=21, p_drop=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_emb = d_model // 2
        self.emb = nn.Embedding(d_seq, self.d_emb)
        # input = residue-i emb + residue-j emb + sequence separation
        self.projection = nn.Linear(d_model + 1, d_model)
        self.pos = PositionalEncoding2D(d_model, p_drop=p_drop)

    def forward(self, seq, idx):
        """Return positionally encoded pair features.

        Args:
            seq: integer residue tokens (B, L); they are fed to ``nn.Embedding``.
            idx: residue indices (B, L).
        """
        n_batch, n_res = seq.shape[0], seq.shape[1]

        tok = self.emb(seq)  # (B, L, d_model//2)
        row_feat = tok.unsqueeze(2).expand(-1, -1, n_res, -1)
        col_feat = tok.unsqueeze(1).expand(-1, n_res, -1, -1)

        # log(|i - j| + 1)
        sep = torch.abs(idx[:, :, None] - idx[:, None, :]) + 1
        sep = torch.log(sep.float()).view(n_batch, n_res, n_res, 1)

        pair = self.projection(torch.cat((row_feat, col_feat, sep), dim=-1))
        return self.pos(pair, idx)

View file

@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from Transformer import LayerNorm, SequenceWeight
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv
def get_seqsep(idx):
    '''
    Signed, log-scaled sequence separation.

    Input:
        - idx: residue indices of given sequence (B,L)
    Output:
        - seqsep: sign(j - i) * min(log(|j - i| + 1), 5.5), shape (B, L, L, 1)

    The original authors note that keeping the sign helps a little.
    '''
    delta = idx[:, None, :] - idx[:, :, None]  # entry [b, i, j] = idx[b, j] - idx[b, i]
    magnitude = torch.clamp(torch.log(torch.abs(delta) + 1.0), 0.0, 5.5)
    signed = torch.sign(delta) * magnitude
    return signed.unsqueeze(-1)
def make_graph(node, idx, emb):
    ''' Build a torch_geometric graph from trunk outputs.

    An edge (i -> j) is created for every residue pair of the same batch
    entry whose indices differ, i.e. all pairs except self-pairs.
    Node ids are flattened as b*L + i.
    '''
    n_batch, n_res = emb.shape[:2]

    # non-zero index separation <=> the two residues are distinct
    connected = (idx[:, None, :] - idx[:, :, None]).abs() > 0
    b, i, j = torch.where(connected)

    edge_index = torch.stack([b * n_res + i, b * n_res + j])
    return Data(x=node.reshape(n_batch * n_res, -1),
                edge_index=edge_index,
                edge_attr=emb[b, i, j])
class UniMPBlock(nn.Module):
    '''Graph transformer block with a residual connection.

    Reference: https://arxiv.org/pdf/2009.03509.pdf
    '''

    def __init__(self,
                 node_dim=64,
                 edge_dim=64,
                 heads=4,
                 dropout=0.15):
        super().__init__()
        concat_dim = node_dim * heads  # width fed to the norm and linear layers
        self.TConv = TransformerConv(node_dim, node_dim, heads, dropout=dropout, edge_dim=edge_dim)
        self.LNorm = LayerNorm(concat_dim)
        self.Linear = nn.Linear(concat_dim, node_dim)
        self.Activ = nn.ELU(inplace=True)

    #@torch.cuda.amp.autocast(enabled=True)
    def forward(self, G):
        """Return a new ``Data`` with updated node features; edges are passed through."""
        nodes_in = G.x
        edge_index = G.edge_index
        edge_attr = G.edge_attr

        hidden = self.Linear(self.LNorm(self.TConv(nodes_in, edge_index, edge_attr)))
        # residual connection; the sum is a fresh tensor, so in-place ELU is safe
        nodes_out = self.Activ(hidden + nodes_in)
        return Data(x=nodes_out, edge_index=edge_index, edge_attr=edge_attr)
class InitStr_Network(nn.Module):
    """Predicts initial coordinates (3 points x 3 values per residue).

    The MSA is collapsed to per-residue node features with learned sequence
    weights, pair features become edge features, and a stack of
    ``UniMPBlock`` graph transformer blocks produces the node embedding from
    which 9 values per residue are regressed.
    """

    def __init__(self,
                 node_dim_in=64,
                 node_dim_hidden=64,
                 edge_dim_in=128,
                 edge_dim_hidden=64,
                 nheads=4,
                 nblocks=3,
                 dropout=0.1):
        super().__init__()

        # embedding layers for node and edge features
        self.norm_node = LayerNorm(node_dim_in)
        self.norm_edge = LayerNorm(edge_dim_in)
        self.encoder_seq = SequenceWeight(node_dim_in, 1, dropout=dropout)

        # +21: seq1hot is concatenated to the node features in forward
        self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden), nn.ELU(inplace=True))
        # +1: signed sequence separation is concatenated to the edge features
        self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+1, edge_dim_hidden), nn.ELU(inplace=True))

        # graph transformer
        layers = []
        for _ in range(nblocks):
            layers.append(UniMPBlock(node_dim_hidden, edge_dim_hidden, nheads, dropout))
        self.transformer = nn.Sequential(*layers)

        # outputs
        self.get_xyz = nn.Linear(node_dim_hidden, 9)

    def forward(self, seq1hot, idx, msa, pair):
        """Return coordinates reshaped to (B, L, 3, 3)."""
        n_batch, n_seq, n_res = msa.shape[:3]

        msa = self.norm_node(msa)
        pair = self.norm_edge(pair)

        # per-sequence weights, rearranged to broadcast against (B, N, L, C)
        weights = self.encoder_seq(msa).reshape(n_batch, n_res, 1, n_seq).permute(0, 3, 1, 2)
        pooled = (weights * msa).sum(dim=1)

        node = self.embed_x(torch.cat((pooled, seq1hot), dim=-1))

        edge = torch.cat((pair, get_seqsep(idx)), dim=-1)
        edge = self.embed_e(edge)

        graph_out = self.transformer(make_graph(node, idx, edge))
        xyz = self.get_xyz(graph_out.x)
        return xyz.reshape(n_batch, n_res, 3, 3)

View file

@ -0,0 +1,175 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from Transformer import LayerNorm
from InitStrGenerator import make_graph
from InitStrGenerator import get_seqsep, UniMPBlock
from Attention_module_w_str import make_graph as make_graph_topk
from Attention_module_w_str import get_bonded_neigh, rbf
from SE3_network import SE3Transformer
from Transformer import _get_clones, create_custom_forward
# Re-generate initial coordinates based on 1) final pair features and 2) the predicted distogram,
# then refine them through multiple SE(3) transformer blocks.
class Regen_Network(nn.Module):
    """Regenerates coordinates and a per-residue state with a graph transformer.

    Node features are concatenated with ``seq1hot``; edge features are
    concatenated with the signed sequence separation and the bonded-neighbour
    feature before being embedded.
    """

    def __init__(self,
                 node_dim_in=64,
                 node_dim_hidden=64,
                 edge_dim_in=128,
                 edge_dim_hidden=64,
                 state_dim=8,
                 nheads=4,
                 nblocks=3,
                 dropout=0.0):
        super().__init__()

        # embedding layers for node and edge features
        self.norm_node = LayerNorm(node_dim_in)
        self.norm_edge = LayerNorm(edge_dim_in)

        # +21: seq1hot; +2: seqsep and neighbour features added in forward
        self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden), LayerNorm(node_dim_hidden))
        self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+2, edge_dim_hidden), LayerNorm(edge_dim_hidden))

        # graph transformer
        layers = []
        for _ in range(nblocks):
            layers.append(UniMPBlock(node_dim_hidden, edge_dim_hidden, nheads, dropout))
        self.transformer = nn.Sequential(*layers)

        # outputs
        self.get_xyz = nn.Linear(node_dim_hidden, 9)
        self.norm_state = LayerNorm(node_dim_hidden)
        self.get_state = nn.Linear(node_dim_hidden, state_dim)

    def forward(self, seq1hot, idx, node, edge):
        """Return ``(xyz, state)`` shaped (B, L, 3, 3) and (B, L, -1)."""
        n_batch, n_res = node.shape[:2]

        node = self.norm_node(node)
        edge = self.norm_edge(edge)

        node = self.embed_x(torch.cat((node, seq1hot), dim=-1))

        seqsep = get_seqsep(idx)
        neighbor = get_bonded_neigh(idx)
        edge = self.embed_e(torch.cat((edge, seqsep, neighbor), dim=-1))

        graph_out = self.transformer(make_graph(node, idx, edge))
        node_out = graph_out.x

        xyz = self.get_xyz(node_out)
        state = self.get_state(self.norm_state(node_out))
        return xyz.reshape(n_batch, n_res, 3, 3), state.reshape(n_batch, n_res, -1)
class Refine_Network(nn.Module):
    """One SE(3)-transformer refinement step for the coordinates.

    Node features (MSA features, ``seq1hot`` and the previous state) and edge
    features (pair features, RBF-encoded distances between the atoms at
    index 1 and the bonded-neighbour feature) are embedded and passed to an
    ``SE3Transformer``, whose degree-1 output is used as coordinate offsets.
    """

    def __init__(self, d_node=64, d_pair=128, d_state=16,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
        super().__init__()
        d_l0_in = SE3_param['l0_in_features']
        d_edge = SE3_param['num_edge_features']

        self.norm_msa = LayerNorm(d_node)
        self.norm_pair = LayerNorm(d_pair)
        self.norm_state = LayerNorm(d_state)

        # +21: seq1hot, +d_state: previous state (both concatenated in forward)
        self.embed_x = nn.Linear(d_node+21+d_state, d_l0_in)
        self.embed_e1 = nn.Linear(d_pair, d_edge)
        # +36+1: presumably 36 RBF features plus 1 neighbour feature - TODO confirm
        self.embed_e2 = nn.Linear(d_edge+36+1, d_edge)

        self.norm_node = LayerNorm(d_l0_in)
        self.norm_edge1 = LayerNorm(d_edge)
        self.norm_edge2 = LayerNorm(d_edge)

        self.se3 = SE3Transformer(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, seq1hot, idx, top_k=64):
        """Return ``(xyz_new, state)`` after one refinement step (autocast disabled)."""
        # process node & pair features
        n_batch, n_res = msa.shape[:2]

        node = self.norm_msa(msa)
        pair = self.norm_pair(pair)
        state = self.norm_state(state)

        node = torch.cat((node, seq1hot, state), dim=-1)
        node = self.norm_node(self.embed_x(node))

        pair = self.norm_edge1(self.embed_e1(pair))
        neighbor = get_bonded_neigh(idx)
        center = xyz[:, :, 1, :]  # atom at index 1 (CA, per the comment below)
        rbf_feat = rbf(torch.cdist(center, center))
        pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        pair = self.norm_edge2(self.embed_e2(pair))

        # define graph
        G = make_graph_topk(xyz, pair, idx, top_k=top_k)
        # l1 features = displacement vector to CA
        l1_feats = (xyz - xyz[:, :, 1, :].unsqueeze(2)).reshape(n_batch * n_res, -1, 3)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(n_batch * n_res, -1, 1), l1_feats)

        state = shift['0'].reshape(n_batch, n_res, -1)       # (B, L, C)
        offset = shift['1'].reshape(n_batch, n_res, -1, 3)   # (B, L, 3, 3)

        # the centre atom moves by its own offset; the other two are placed
        # relative to the NEW centre using their offsets
        CA_new = xyz[:, :, 1] + offset[:, :, 1]
        N_new = CA_new + offset[:, :, 0]
        C_new = CA_new + offset[:, :, 2]
        xyz_new = torch.stack([N_new, CA_new, C_new], dim=2)

        return xyz_new, state
class Refine_module(nn.Module):
    """Iterative SE(3) refinement with early stopping on predicted lDDT.

    Coordinates are regenerated from node/edge features, duplicated together
    with a copy whose third coordinate is negated (mirror image), and then
    refined repeatedly. The candidate with the highest mean predicted lDDT
    seen so far is returned.

    Args:
        n_module: number of ``Refine_Network`` copies applied per iteration.
        d_node, d_node_hidden, d_pair, d_pair_hidden: feature widths.
        SE3_param: keyword dictionary forwarded to the refinement networks.
        p_drop: dropout probability forwarded to the sub-networks.
    """

    def __init__(self, n_module, d_node=64, d_node_hidden=64, d_pair=128, d_pair_hidden=64,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
        super(Refine_module, self).__init__()
        self.n_module = n_module
        self.proj_edge = nn.Linear(d_pair, d_pair_hidden*2)

        self.regen_net = Regen_Network(node_dim_in=d_node, node_dim_hidden=d_node_hidden,
                                       edge_dim_in=d_pair_hidden*2, edge_dim_hidden=d_pair_hidden,
                                       state_dim=SE3_param['l0_out_features'],
                                       nheads=4, nblocks=3, dropout=p_drop)
        self.refine_net = _get_clones(Refine_Network(d_node=d_node, d_pair=d_pair_hidden*2,
                                                     d_state=SE3_param['l0_out_features'],
                                                     SE3_param=SE3_param, p_drop=p_drop), self.n_module)
        self.norm_state = LayerNorm(SE3_param['l0_out_features'])
        self.pred_lddt = nn.Linear(SE3_param['l0_out_features'], 1)

    def forward(self, node, edge, seq1hot, idx, use_transf_checkpoint=False, eps=1e-4):
        """Refine coordinates and return the best candidate.

        Args:
            node, edge, seq1hot, idx: features for the regeneration and
                refinement networks (batch axis first).
            use_transf_checkpoint: run each refinement network through
                ``torch.utils.checkpoint``.
            eps: minimum improvement of the best mean lDDT that counts as
                progress for the early-stopping counters.

        Returns:
            ``(xyz, lddt)`` of the best candidate, each with a leading axis of
            size 1; ``lddt`` is always 2-dimensional ``(1, L)``.
        """
        edge = self.proj_edge(edge)

        xyz, state = self.regen_net(seq1hot, idx, node, edge)

        # DOUBLE IT w/ Mirror images
        xyz = torch.cat([xyz, xyz*torch.tensor([1,1,-1], dtype=xyz.dtype, device=xyz.device)])
        state = torch.cat([state, state])
        node = torch.cat([node, node])
        edge = torch.cat([edge, edge])
        idx = torch.cat([idx, idx])
        seq1hot = torch.cat([seq1hot, seq1hot])

        best_xyz = xyz
        # Same (batch, L) layout as the per-iteration lddt, so the returned
        # shape does not depend on whether any iteration improved on zero.
        best_lddt = torch.zeros((xyz.shape[0], xyz.shape[1]), device=xyz.device)
        prev_lddt = 0.0
        no_impr = 0       # consecutive iterations without beating the previous one
        no_impr_best = 0  # consecutive iterations without beating the best so far
        for i_iter in range(200):
            for i_m in range(self.n_module):
                if use_transf_checkpoint:
                    xyz, state = checkpoint.checkpoint(create_custom_forward(self.refine_net[i_m], top_k=64), node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx)
                else:
                    xyz, state = self.refine_net[i_m](node.float(), edge.float(), xyz.detach().float(), state.float(), seq1hot, idx, top_k=64)
            #
            lddt = self.pred_lddt(self.norm_state(state))
            lddt = torch.clamp(lddt, 0.0, 1.0)[...,0]
            # detach so that logging also works when gradients are being tracked
            print (f"SE(3) iteration {i_iter} {lddt.mean(-1).detach().cpu().numpy()}")
            if lddt.mean(-1).max() <= prev_lddt+eps:
                no_impr += 1
            else:
                no_impr = 0
            if lddt.mean(-1).max() <= best_lddt.mean(-1).max()+eps:
                no_impr_best += 1
            else:
                no_impr_best = 0
            # the stop check runs BEFORE the best candidate is updated
            if no_impr > 10 or no_impr_best > 20:
                break
            if lddt.mean(-1).max() > best_lddt.mean(-1).max():
                best_lddt = lddt
                best_xyz = xyz
            prev_lddt = lddt.mean(-1).max()

        # choose between the original and the mirrored candidate
        pick = best_lddt.mean(-1).argmax()
        return best_xyz[pick][None], best_lddt[pick][None]

View file

@ -0,0 +1,131 @@
import torch
import torch.nn as nn
from Embeddings import MSA_emb, Pair_emb_wo_templ, Pair_emb_w_templ, Templ_emb
from Attention_module_w_str import IterativeFeatureExtractor
from DistancePredictor import DistanceNetwork
from Refine_module import Refine_module
class RoseTTAFoldModule(nn.Module):
    """Model without the refinement stage.

    Pipeline: MSA / pair (optionally template) embeddings ->
    ``IterativeFeatureExtractor`` -> ``DistanceNetwork``.
    """

    def __init__(self, n_module=4, n_module_str=4, n_layer=4,
                 d_msa=64, d_pair=128, d_templ=64,
                 n_head_msa=4, n_head_pair=8, n_head_templ=4,
                 d_hidden=64, r_ff=4, n_resblock=1, p_drop=0.1,
                 performer_L_opts=None, performer_N_opts=None,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 use_templ=False):
        super().__init__()
        self.use_templ = use_templ

        # embeddings
        self.msa_emb = MSA_emb(d_model=d_msa, p_drop=p_drop, max_len=5000)
        if use_templ:
            # template embedding is built without dropout
            self.templ_emb = Templ_emb(d_templ=d_templ, n_att_head=n_head_templ, r_ff=r_ff,
                                       performer_opts=performer_L_opts, p_drop=0.0)
            self.pair_emb = Pair_emb_w_templ(d_model=d_pair, d_templ=d_templ, p_drop=p_drop)
        else:
            self.pair_emb = Pair_emb_wo_templ(d_model=d_pair, p_drop=p_drop)

        # trunk
        self.feat_extractor = IterativeFeatureExtractor(
            n_module=n_module,
            n_module_str=n_module_str,
            n_layer=n_layer,
            d_msa=d_msa, d_pair=d_pair, d_hidden=d_hidden,
            n_head_msa=n_head_msa,
            n_head_pair=n_head_pair,
            r_ff=r_ff,
            n_resblock=n_resblock,
            p_drop=p_drop,
            performer_N_opts=performer_N_opts,
            performer_L_opts=performer_L_opts,
            SE3_param=SE3_param,
        )
        # output head
        self.c6d_predictor = DistanceNetwork(d_pair, p_drop=p_drop)

    def forward(self, msa, seq, idx, t1d=None, t2d=None):
        """Return ``(logits, xyz, lddt)`` with ``lddt`` viewed as (B, L).

        ``msa`` must be 3-dimensional (B, N, L); ``seq`` holds integer residue
        tokens below 21 (it is one-hot encoded with 21 classes). ``t1d`` and
        ``t2d`` are only used when the model was built with ``use_templ=True``.
        """
        B, N, L = msa.shape

        # Get embeddings
        msa = self.msa_emb(msa, idx)
        if self.use_templ:
            pair = self.pair_emb(seq, idx, self.templ_emb(t1d, t2d, idx))
        else:
            pair = self.pair_emb(seq, idx)

        # Extract features
        seq1hot = torch.nn.functional.one_hot(seq, num_classes=21).float()
        msa, pair, xyz, lddt = self.feat_extractor(msa, pair, seq1hot, idx)

        # Predict 6D coords
        logits = self.c6d_predictor(pair)

        return logits, xyz, lddt.view(B, L)
class RoseTTAFoldModule_e2e(nn.Module):
    """End-to-end model: trunk, distance head and SE(3) refinement.

    The refinement module consumes the softmax probabilities of all four
    distance-network heads concatenated along the channel axis.
    """

    def __init__(self, n_module=4, n_module_str=4, n_module_ref=4, n_layer=4,
                 d_msa=64, d_pair=128, d_templ=64,
                 n_head_msa=4, n_head_pair=8, n_head_templ=4,
                 d_hidden=64, r_ff=4, n_resblock=1, p_drop=0.0,
                 performer_L_opts=None, performer_N_opts=None,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 REF_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 use_templ=False):
        super().__init__()
        self.use_templ = use_templ

        # embeddings
        self.msa_emb = MSA_emb(d_model=d_msa, p_drop=p_drop, max_len=5000)
        if use_templ:
            # template embedding is built without dropout
            self.templ_emb = Templ_emb(d_templ=d_templ, n_att_head=n_head_templ, r_ff=r_ff,
                                       performer_opts=performer_L_opts, p_drop=0.0)
            self.pair_emb = Pair_emb_w_templ(d_model=d_pair, d_templ=d_templ, p_drop=p_drop)
        else:
            self.pair_emb = Pair_emb_wo_templ(d_model=d_pair, p_drop=p_drop)

        # trunk
        self.feat_extractor = IterativeFeatureExtractor(
            n_module=n_module,
            n_module_str=n_module_str,
            n_layer=n_layer,
            d_msa=d_msa, d_pair=d_pair, d_hidden=d_hidden,
            n_head_msa=n_head_msa,
            n_head_pair=n_head_pair,
            r_ff=r_ff,
            n_resblock=n_resblock,
            p_drop=p_drop,
            performer_N_opts=performer_N_opts,
            performer_L_opts=performer_L_opts,
            SE3_param=SE3_param,
        )
        self.c6d_predictor = DistanceNetwork(d_pair, p_drop=p_drop)

        # refinement; d_pair=130 is the width of the concatenated head
        # probabilities (presumably 37 + 37 + 37 + 19 - confirm against
        # DistanceNetwork)
        self.refine = Refine_module(n_module_ref, d_node=d_msa, d_pair=130,
                                    d_node_hidden=d_hidden, d_pair_hidden=d_hidden,
                                    SE3_param=REF_param, p_drop=p_drop)

    def forward(self, msa, seq, idx, t1d=None, t2d=None, prob_s=None, return_raw=False, refine_only=False):
        """Run the model.

        Modes:
            * default: returns ``(logits, msa, ref_xyz, ref_lddt)``.
            * ``return_raw=True``: skips refinement and returns
              ``(logits, msa, xyz, lddt)`` from the trunk.
            * ``refine_only=True``: skips the trunk; ``msa`` is then used
              directly as node features and ``prob_s`` must be supplied.
              Returns ``(ref_xyz, ref_lddt)``.
        """
        seq1hot = torch.nn.functional.one_hot(seq, num_classes=21).float()

        if not refine_only:
            B, N, L = msa.shape

            # Get embeddings
            msa = self.msa_emb(msa, idx)
            if self.use_templ:
                pair = self.pair_emb(seq, idx, self.templ_emb(t1d, t2d, idx))
            else:
                pair = self.pair_emb(seq, idx)

            # Extract features
            msa, pair, xyz, lddt = self.feat_extractor(msa, pair, seq1hot, idx)

            # Predict 6D coords and convert each head to probabilities
            logits = self.c6d_predictor(pair)
            head_probs = [torch.softmax(head, dim=1) for head in logits]  # each (B, C, L, L)
            prob_s = torch.cat(head_probs, dim=1).permute(0, 2, 3, 1)

        B, L = msa.shape[:2]
        if return_raw:
            return logits, msa, xyz, lddt.view(B, L)

        ref_xyz, ref_lddt = self.refine(msa, prob_s, seq1hot, idx)
        ref_lddt = ref_lddt.view(B, L)

        if refine_only:
            return ref_xyz, ref_lddt
        return logits, msa, ref_xyz, ref_lddt

View file

@ -0,0 +1,108 @@
import torch
import torch.nn as nn
from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
from equivariant_attention.modules import GConvSE3, GNormSE3
from equivariant_attention.fibers import Fiber
class TFN(nn.Module):
    """SE(3) equivariant GCN"""

    def __init__(self, num_layers=2, num_channels=32, num_nonlin_layers=1, num_degrees=3,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=3,
                 num_edge_features=32, use_self=True):
        super().__init__()
        # Build the network
        self.num_layers = num_layers
        self.num_nlayers = num_nonlin_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.use_self = use_self

        # the degree-1 output is only requested when l1_out_features > 0
        out_structure = {0: l0_out_features}
        if l1_out_features > 0:
            out_structure[1] = l1_out_features

        fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
                  'mid': Fiber(self.num_degrees, self.num_channels),
                  'out': Fiber(dictionary=out_structure)}

        self.block0 = self._build_gcn(fibers)

    def _build_gcn(self, fibers):
        """Stack (conv, norm) pairs followed by a final conv to the output fiber."""
        layers = []
        fiber_in = fibers['in']
        for _ in range(self.num_layers - 1):
            layers.append(GConvSE3(fiber_in, fibers['mid'],
                                   self_interaction=self.use_self, edge_dim=self.edge_dim))
            layers.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers))
            fiber_in = fibers['mid']
        # the last convolution always reads the 'mid' fiber
        layers.append(GConvSE3(fibers['mid'], fibers['out'],
                               self_interaction=self.use_self, edge_dim=self.edge_dim))
        return nn.ModuleList(layers)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, G, type_0_features, type_1_features):
        """Apply all layers to the degree-0 / degree-1 features (autocast disabled)."""
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)
        feats = {'0': type_0_features, '1': type_1_features}
        for layer in self.block0:
            feats = layer(feats, G=G, r=r, basis=basis)
        return feats
class SE3Transformer(nn.Module):
    """SE(3) equivariant GCN with attention"""

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 si_m='1x1', si_e='att',
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=3,
                 num_edge_features=32, x_ij=None):
        super().__init__()
        # Build the network
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.div = div
        self.n_heads = n_heads
        self.si_m, self.si_e = si_m, si_e
        self.x_ij = x_ij

        # the degree-1 output is only requested when l1_out_features > 0
        out_structure = {0: l0_out_features}
        if l1_out_features > 0:
            out_structure[1] = l1_out_features

        fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
                  'mid': Fiber(self.num_degrees, self.num_channels),
                  'out': Fiber(dictionary=out_structure)}

        self.Gblock = self._build_gcn(fibers)

    def _build_gcn(self, fibers):
        """Stack num_layers (attention, norm) pairs plus a final attention layer."""
        # Equivariant layers
        layers = []
        fiber_in = fibers['in']
        for _ in range(self.num_layers):
            layers.append(GSE3Res(fiber_in, fibers['mid'], edge_dim=self.edge_dim,
                                  div=self.div, n_heads=self.n_heads,
                                  learnable_skip=True, skip='cat',
                                  selfint=self.si_m, x_ij=self.x_ij))
            layers.append(GNormBias(fibers['mid']))
            fiber_in = fibers['mid']
        # output layer: single head (the original spelled this as min(1, 2)), div=1
        layers.append(GSE3Res(fibers['mid'], fibers['out'], edge_dim=self.edge_dim,
                              div=1, n_heads=1, learnable_skip=True,
                              skip='cat', selfint=self.si_e, x_ij=self.x_ij))
        return nn.ModuleList(layers)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, G, type_0_features, type_1_features):
        """Apply all layers to the degree-0 / degree-1 features (autocast disabled)."""
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)
        feats = {'0': type_0_features, '1': type_1_features}
        for layer in self.Gblock:
            feats = layer(feats, G=G, r=r, basis=basis)
        return feats

View file

@ -0,0 +1,480 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from performer_pytorch import SelfAttention
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
    """Wrap ``module`` so it can be called with positional inputs only.

    The keyword arguments given here are bound once and passed on every call;
    the wrapper is what gradient checkpointing invokes.
    """
    return lambda *inputs: module(*inputs, **kwargs)
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learnable gain and bias.

    ``a_2`` (initialised to ones) scales and ``b_2`` (initialised to zeros)
    shifts the normalised values. The biased variance is used and ``eps`` is
    added inside the square root.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(d_model))
        self.b_2 = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return the normalised tensor; the input itself is not modified."""
        centered = x - x.mean(-1, keepdim=True)
        std = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        # same evaluation order as before: scale, divide, then shift
        return self.a_2 * centered / std + self.b_2
class FeedForwardLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: hidden width.
        p_drop: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, p_drop=0.1):
        super(FeedForwardLayer, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        # Not in-place: the in-place ReLU below keeps its output for the
        # backward pass, and overwriting that tensor with dropout makes
        # autograd fail in training mode.
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, src):
        """Return the transformed tensor with the same trailing width as ``src``."""
        hidden = F.relu_(self.linear1(src))  # in-place on the fresh linear output
        return self.linear2(self.dropout(hidden))
class MultiheadAttention(nn.Module):
    """Scaled dot-product multi-head attention for (batch, length, channel) inputs.

    Args:
        d_model: query width and output width; must be divisible by ``heads``.
        heads: number of attention heads.
        k_dim: key input width (defaults to ``d_model``).
        v_dim: value input width (defaults to ``d_model``).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
        super(MultiheadAttention, self).__init__()
        if k_dim is None:
            k_dim = d_model
        if v_dim is None:
            v_dim = d_model

        self.heads = heads
        self.d_model = d_model
        self.d_k = d_model // heads
        self.scaling = 1/math.sqrt(self.d_k)

        self.to_query = nn.Linear(d_model, d_model)
        self.to_key = nn.Linear(k_dim, d_model)
        self.to_value = nn.Linear(v_dim, d_model)
        self.to_out = nn.Linear(d_model, d_model)

        # Not in-place: softmax keeps its output for the backward pass, so
        # overwriting it in place breaks autograd in training mode.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, return_att=False):
        """Attend from ``query`` (B, L1, d_model) to ``key``/``value`` (B, L2, ...).

        Returns:
            ``out`` of shape (B, L1, d_model). With ``return_att=True`` also
            the attention weights after dropout, averaged with their transpose
            and laid out as (B, L1, L2, heads); that averaging requires
            L1 == L2.
        """
        batch, L1 = query.shape[:2]
        batch, L2 = key.shape[:2]
        q = self.to_query(query).view(batch, L1, self.heads, self.d_k).permute(0,2,1,3) # (B, h, L, d_k)
        k = self.to_key(key).view(batch, L2, self.heads, self.d_k).permute(0,2,1,3) # (B, h, L, d_k)
        v = self.to_value(value).view(batch, L2, self.heads, self.d_k).permute(0,2,1,3)
        #
        attention = torch.matmul(q, k.transpose(-2, -1))*self.scaling
        attention = F.softmax(attention, dim=-1) # (B, h, L1, L2)
        attention = self.dropout(attention)
        #
        out = torch.matmul(attention, v) # (B, h, L, d_k)
        out = out.permute(0,2,1,3).contiguous().view(batch, L1, -1)
        #
        out = self.to_out(out)
        if return_att:
            # symmetrise over the two length axes, then move heads last
            attention = 0.5*(attention + attention.permute(0,1,3,2))
            return out, attention.permute(0,2,3,1)
        return out
# Own implementation for tied multihead attention
class TiedMultiheadAttention(nn.Module):
    """Multi-head attention whose attention map is shared by all N sequences.

    The query/key logits of every sequence are summed (scaled by
    ``1/sqrt(N * d_k)``) into a single (B, heads, L, L) map that is then
    applied to each sequence's values.

    Args:
        d_model: query width and output width; must be divisible by ``heads``.
        heads: number of attention heads.
        k_dim: key input width (defaults to ``d_model``).
        v_dim: value input width (defaults to ``d_model``).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
        super(TiedMultiheadAttention, self).__init__()
        if k_dim is None:
            k_dim = d_model
        if v_dim is None:
            v_dim = d_model

        self.heads = heads
        self.d_model = d_model
        self.d_k = d_model // heads
        self.scaling = 1/math.sqrt(self.d_k)

        self.to_query = nn.Linear(d_model, d_model)
        self.to_key = nn.Linear(k_dim, d_model)
        self.to_value = nn.Linear(v_dim, d_model)
        self.to_out = nn.Linear(d_model, d_model)

        # Not in-place: softmax keeps its output for the backward pass, so
        # overwriting it in place breaks autograd in training mode.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, return_att=False):
        """Apply tied attention to (B, N, L, C) inputs.

        Returns:
            ``out`` of shape (B, N, L, d_model). With ``return_att=True`` also
            the shared attention map after dropout, averaged with its
            transpose and permuted with ``(0, 3, 1, 2)``, i.e. laid out as
            (B, L, heads, L).
        """
        B, N, L = query.shape[:3]
        q = self.to_query(query).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
        k = self.to_key(key).view(B, N, L, self.heads, self.d_k).permute(0,1,3,4,2).contiguous() # (B, N, h, k, l)
        v = self.to_value(value).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
        #
        # Equivalent to summing per-sequence logits and dividing by sqrt(N * d_k);
        # the scale is folded into q before the contraction.
        scale = self.scaling / math.sqrt(N)
        q = q * scale
        attention = torch.einsum('bnhik,bnhkj->bhij', q, k)
        attention = F.softmax(attention, dim=-1) # (B, h, L, L)
        attention = self.dropout(attention)
        attention = attention.unsqueeze(1) # (B, 1, h, L, L)
        #
        out = torch.matmul(attention, v) # (B, N, h, L, d_k)
        out = out.permute(0,1,3,2,4).contiguous().view(B, N, L, -1)
        #
        out = self.to_out(out)
        if return_att:
            attention = attention.squeeze(1)
            attention = 0.5*(attention + attention.permute(0,1,3,2))
            # NOTE(review): MultiheadAttention in this file uses permute(0,2,3,1)
            # (heads last) here - confirm that the (B, L, heads, L) layout
            # produced below is what the consumer expects.
            attention = attention.permute(0,3,1,2)
            return out, attention
        return out
class SequenceWeight(nn.Module):
    """Per-position attention of the first row of an MSA over all of its rows.

    For input (B, N, L, d_model) the output is (B, L, heads, 1, N): for every
    position and head, a softmax distribution over the N rows.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.d_k = d_model // heads
        self.scale = 1.0 / math.sqrt(self.d_k)
        self.to_query = nn.Linear(d_model, d_model)
        self.to_key = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        by_position = msa.permute(0, 2, 1, 3)  # (B, L, N, K)
        first_row = by_position[:, :, :1]       # (B, L, 1, K), row 0 acts as the query

        q = self.to_query(first_row).view(B, L, 1, self.heads, self.d_k)
        q = q.permute(0, 1, 3, 2, 4).contiguous()   # (B, L, h, 1, k)
        k = self.to_key(by_position).view(B, L, N, self.heads, self.d_k)
        k = k.permute(0, 1, 3, 4, 2).contiguous()   # (B, L, h, k, N)

        scores = torch.matmul(q * self.scale, k)     # (B, L, h, 1, N)
        return self.dropout(F.softmax(scores, dim=-1))
# Own implementation for multihead attention (Input shape: Batch, Len, Emb)
class SoftTiedMultiheadAttention(nn.Module):
    """Tied attention over 4-D input where each row's contribution is weighted.

    Inputs are (B, N, L, channels). A ``SequenceWeight`` module produces one
    weight per (row, head, position); the queries are multiplied by it before
    the logits are summed over the N axis, giving a single (B, h, L, L)
    attention map that is applied to every row's values.

    Args:
        d_model: channel size of the query and of the output.
        heads: number of attention heads (``d_model // heads`` per head).
        k_dim: channel size of ``key``; defaults to ``d_model``.
        v_dim: channel size of ``value``; defaults to ``d_model``.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, d_model, heads, k_dim=None, v_dim=None, dropout=0.1):
        super(SoftTiedMultiheadAttention, self).__init__()
        if k_dim is None:
            k_dim = d_model
        if v_dim is None:
            v_dim = d_model
        self.heads = heads
        self.d_model = d_model
        self.d_k = d_model // heads
        self.scale = 1.0 / math.sqrt(self.d_k)
        self.seq_weight = SequenceWeight(d_model, heads, dropout=dropout)
        self.to_query = nn.Linear(d_model, d_model)
        self.to_key = nn.Linear(k_dim, d_model)
        self.to_value = nn.Linear(v_dim, d_model)
        self.to_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, query, key, value, return_att=False):
        """Apply soft-tied attention.

        Args:
            query, key, value: tensors whose first three dims are (B, N, L).
            return_att: when True, also return the attention map.

        Returns:
            ``out`` of shape (B, N, L, d_model), or ``(out, att)`` where ``att``
            is the attention averaged with its transpose, shaped (B, L, L, h).
        """
        B, N, L = query.shape[:3]
        #
        seq_weight = self.seq_weight(query) # (B, L, h, 1, N)
        seq_weight = seq_weight.permute(0,4,2,1,3) # (B, N, h, L, 1)
        #
        q = self.to_query(query).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
        k = self.to_key(key).view(B, N, L, self.heads, self.d_k).permute(0,1,3,4,2).contiguous() # (B, N, h, k, l)
        v = self.to_value(value).view(B, N, L, self.heads, self.d_k).permute(0,1,3,2,4).contiguous() # (B, N, h, l, k)
        #
        q = q * seq_weight # (B, N, h, l, k), broadcast over the channel axis
        k = k * self.scale
        # Contract over channels and over N at once: one tied map per head.
        attention = torch.einsum('bnhik,bnhkj->bhij', q, k)
        attention = F.softmax(attention, dim=-1) # (B, h, L, L)
        attention = self.dropout(attention)
        # Drop the large intermediates before allocating the output.
        del q, k, seq_weight
        #
        out = torch.einsum('bhij,bnhjk->bnhik', attention, v) # (B, N, h, L, d_k)
        out = out.permute(0,1,3,2,4).contiguous().view(B, N, L, -1)
        #
        out = self.to_out(out)
        if return_att:
            # attention is already (B, h, L, L). It must not be squeezed on
            # dim 1: with heads == 1 that would remove the head axis and break
            # the 4-D permutes below.
            attention = 0.5*(attention + attention.permute(0,1,3,2))
            attention = attention.permute(0,2,3,1)
            return out, attention
        return out
class DirectMultiheadAttention(nn.Module):
    """Attention whose weights come from a linear projection of ``src``.

    ``src`` (B, L, L, d_in) is projected to one logit per head, softmax-ed over
    dim 1, and used to mix the projected ``tgt`` (B, N, L, d_out) along L.
    """

    def __init__(self, d_in, d_out, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.proj_pair = nn.Linear(d_in, heads)
        self.drop = nn.Dropout(dropout, inplace=True)
        # values are a linear projection of the given msa
        self.proj_msa = nn.Linear(d_out, d_out)
        # output projection applied after attention
        self.proj_out = nn.Linear(d_out, d_out)

    def forward(self, src, tgt):
        B, N, L = tgt.shape[:3]

        weights = F.softmax(self.proj_pair(src), dim=1)   # (B, L, L, h), normalized over dim 1
        weights = self.drop(weights.permute(0, 3, 1, 2))  # (B, h, L, L)
        weights = weights.unsqueeze(1)                    # (B, 1, h, L, L)

        # channels first, then split into (channels // heads, heads)
        values = self.proj_msa(tgt).permute(0, 3, 1, 2).contiguous()
        values = values.view(B, -1, self.heads, N, L)     # (B, -1, h, N, L)

        mixed = torch.matmul(values, weights).view(B, -1, N, L)
        return self.proj_out(mixed.permute(0, 2, 3, 1))   # (B, N, L, K)
class MaskedDirectMultiheadAttention(nn.Module):
    """Masked attention: logits from ``query``/``key``, applied to 4-D ``value``.

    ``query`` and ``key`` are (B, L, d_in); ``value`` is (B, N, L, d_out). One
    (B, h, L, L) attention map is computed and reused for all N rows of value.
    """

    def __init__(self, d_in, d_out, heads, d_k=32, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scaling = 1/math.sqrt(d_k)
        self.to_query = nn.Linear(d_in, heads*d_k)
        self.to_key = nn.Linear(d_in, heads*d_k)
        self.to_value = nn.Linear(d_out, d_out)
        self.to_out = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, query, key, value, mask):
        """Return the attended values, shape (B, N, L, d_out).

        Entries where ``mask < 0.5`` are excluded: their logits are set to the
        most negative finite value of the dtype before the softmax.
        """
        batch, N, L = value.shape[:3]

        q = self.to_query(query).view(batch, L, self.heads, -1).permute(0, 2, 1, 3)  # (B, h, L, -1)
        k = self.to_key(key).view(batch, L, self.heads, -1).permute(0, 2, 1, 3)      # (B, h, L, -1)
        v = self.to_value(value).view(batch, N, L, self.heads, -1)
        v = v.permute(0, 3, 1, 2, 4)                                                  # (B, h, N, L, -1)

        q = q * self.scaling
        logits = torch.matmul(q, k.transpose(-2, -1))                                 # (B, h, L, L)
        blocked = mask < 0.5
        logits = logits.masked_fill(blocked, torch.finfo(q.dtype).min)
        weights = self.dropout(F.softmax(logits, dim=-1))

        mixed = torch.einsum('bhij,bhnjk->bhnik', weights, v)                         # (B, h, N, L, d_out//h)
        mixed = mixed.permute(0, 2, 3, 1, 4).contiguous().view(batch, N, L, -1)
        return self.to_out(mixed)
# Use PreLayerNorm for more stable training
class EncoderLayer(nn.Module):
    """Pre-LayerNorm encoder block: attention + feed-forward, both residual.

    The attention flavor is chosen at construction: a performer
    ``SelfAttention`` when ``performer_opts`` is given, otherwise
    ``SoftTiedMultiheadAttention`` when ``use_tied`` is set, otherwise plain
    ``MultiheadAttention``.
    """

    def __init__(self, d_model, d_ff, heads, p_drop=0.1, performer_opts=None, use_tied=False):
        super().__init__()
        self.use_performer = performer_opts is not None
        self.use_tied = use_tied
        # multihead attention
        if self.use_performer:
            attn = SelfAttention(dim=d_model, heads=heads, dropout=p_drop,
                                 generalized_attention=True, **performer_opts)
        elif use_tied:
            attn = SoftTiedMultiheadAttention(d_model, heads, dropout=p_drop)
        else:
            attn = MultiheadAttention(d_model, heads, dropout=p_drop)
        self.attn = attn
        # feedforward
        self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
        # normalization modules
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p_drop, inplace=True)
        self.dropout2 = nn.Dropout(p_drop, inplace=True)

    def forward(self, src, return_att=False):
        """Run the block on ``src`` (first three dims B, N, L).

        Returns the updated tensor, or ``(tensor, att)`` when ``return_att``.
        """
        B, N, L = src.shape[:3]
        normed = self.norm1(src)
        if not self.use_tied:
            # untied attention works on (batch, length, emb): fold N into batch
            normed = normed.reshape(B * N, L, -1)

        att = None
        if return_att:
            attended, att = self.attn(normed, normed, normed, return_att=return_att)
        else:
            attended = self.attn(normed, normed, normed)
        src = src + self.dropout1(attended.reshape(B, N, L, -1))

        # feed-forward with pre-normalization
        src = src + self.dropout2(self.ff(self.norm2(src)))
        return (src, att) if return_att else src
# AxialTransformer with tied attention for L dimension
class AxialEncoderLayer(nn.Module):
    """Axial encoder block: attention along L, then along N, then feed-forward.

    Every sub-layer is pre-normalized and residual. Row (L) attention may be
    tied, soft-tied or untied; column (N) attention may be tied or untied.
    Untied attention is a performer ``SelfAttention`` when ``performer_opts``
    is given, else ``MultiheadAttention``.
    """

    def __init__(self, d_model, d_ff, heads, p_drop=0.1, performer_opts=None,
                 use_tied_row=False, use_tied_col=False, use_soft_row=False):
        super().__init__()
        self.use_performer = performer_opts is not None
        self.use_tied_row = use_tied_row
        self.use_tied_col = use_tied_col
        self.use_soft_row = use_soft_row

        def untied_attention():
            # shared by the row and column branches below
            if self.use_performer:
                return SelfAttention(dim=d_model, heads=heads, dropout=p_drop,
                                     generalized_attention=True, **performer_opts)
            return MultiheadAttention(d_model, heads, dropout=p_drop)

        # attention over L (rows)
        if use_tied_row:
            self.attn_L = TiedMultiheadAttention(d_model, heads, dropout=p_drop)
        elif use_soft_row:
            self.attn_L = SoftTiedMultiheadAttention(d_model, heads, dropout=p_drop)
        else:
            self.attn_L = untied_attention()
        # attention over N (columns)
        if use_tied_col:
            self.attn_N = TiedMultiheadAttention(d_model, heads, dropout=p_drop)
        else:
            self.attn_N = untied_attention()
        # feedforward
        self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
        # normalization modules
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p_drop, inplace=True)
        self.dropout2 = nn.Dropout(p_drop, inplace=True)
        self.dropout3 = nn.Dropout(p_drop, inplace=True)

    def forward(self, src, return_att=False):
        """Run the block on ``src`` of shape (BATCH, NSEQ, NRES, EMB).

        ``return_att`` is accepted for signature compatibility but is not
        used: only the updated tensor is returned.
        """
        B, N, L = src.shape[:3]

        # --- attention over L ---
        rows = self.norm1(src)
        if self.use_tied_row or self.use_soft_row:
            rows = self.attn_L(rows, rows, rows)
        else:
            rows = rows.reshape(B * N, L, -1)
            rows = self.attn_L(rows, rows, rows).reshape(B, N, L, -1)
        src = src + self.dropout1(rows)

        # --- attention over N ---
        cols = self.norm2(src).permute(0, 2, 1, 3)  # (B, L, N, EMB)
        if self.use_tied_col:
            cols = self.attn_N(cols, cols, cols)
        else:
            cols = cols.reshape(B * L, N, -1)
            cols = self.attn_N(cols, cols, cols).reshape(B, L, N, -1)
        src = src + self.dropout2(cols.permute(0, 2, 1, 3))

        # --- feed-forward ---
        src = src + self.dropout3(self.ff(self.norm3(src)))
        return src
class Encoder(nn.Module):
    """Stack of ``n_layer`` clones of ``enc_layer`` applied one after another.

    Args:
        enc_layer: layer to clone; called as ``layer(x, return_att=...)``.
        n_layer: number of clones in the stack.
    """

    def __init__(self, enc_layer, n_layer):
        super(Encoder, self).__init__()
        self.layers = _get_clones(enc_layer, n_layer)
        self.n_layer = n_layer

    def forward(self, src, return_att=False):
        """Run ``src`` through every layer in order.

        Args:
            src: input tensor handed to the first layer.
            return_att: forwarded to every layer.

        Returns:
            The last layer's output. If ``return_att`` is True and the layers
            return ``(output, attention)`` pairs, returns
            ``(output, attention)`` with the attention of the last layer.
            Layers that ignore ``return_att`` and return a bare tensor are
            passed through unchanged.
        """
        output = src
        att = None
        for layer in self.layers:
            result = layer(output, return_att=return_att)
            if return_att and isinstance(result, tuple) and len(result) == 2:
                # Feed only the tensor to the next layer. Passing the whole
                # (tensor, attention) pair on would break any stack deeper
                # than one layer.
                output, att = result
            else:
                output = result
        if att is not None:
            return output, att
        return output
class CrossEncoderLayer(nn.Module):
    """Pre-LayerNorm cross-attention block: ``tgt`` attends to ``src``.

    ``src`` (channel size ``d_k``) supplies keys and values, ``tgt`` (channel
    size ``d_model``) supplies queries. Attention is a performer
    ``SelfAttention`` when ``performer_opts`` is given, else
    ``MultiheadAttention``.
    """

    def __init__(self, d_model, d_ff, heads, d_k, d_v, performer_opts=None, p_drop=0.1):
        super().__init__()
        self.use_performer = performer_opts is not None
        # multihead attention
        if self.use_performer:
            self.attn = SelfAttention(dim=d_model, k_dim=d_k, heads=heads, dropout=p_drop,
                                      generalized_attention=True, **performer_opts)
        else:
            self.attn = MultiheadAttention(d_model, heads, k_dim=d_k, v_dim=d_v, dropout=p_drop)
        # feedforward
        self.ff = FeedForwardLayer(d_model, d_ff, p_drop=p_drop)
        # normalization modules
        self.norm = LayerNorm(d_k)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p_drop, inplace=True)
        self.dropout2 = nn.Dropout(p_drop, inplace=True)

    def forward(self, src, tgt):
        """Update ``tgt`` from ``src`` and return it.

        Per the original author's notes:
          MSA to Pair: src (N, L, K), tgt (L, L, C)
          Pair to MSA: src (L, L, C), tgt (N, L, K)
        NOTE(review): the original comment gives the attention input layout as
        (SRCLEN, BATCH, EMB) - confirm against the attention module in use.
        """
        # pre-normalize both streams; the q/k/v projections live inside self.attn
        context = self.norm(src)
        queries = self.norm1(tgt)
        tgt = tgt + self.dropout1(self.attn(queries, context, context))
        # feed-forward
        tgt = tgt + self.dropout2(self.ff(self.norm2(tgt)))
        return tgt
class DirectEncoderLayer(nn.Module):
    """Pre-LayerNorm block updating ``tgt`` with ``DirectMultiheadAttention``.

    The attention map is derived from ``src`` (channel size ``d_in``) and
    applied to ``tgt`` (channel size ``d_out``), followed by a feed-forward
    sub-layer. Both sub-layers are residual.
    """

    def __init__(self, heads, d_in, d_out, d_ff, symmetrize=True, p_drop=0.1):
        super().__init__()
        self.symmetrize = symmetrize
        self.attn = DirectMultiheadAttention(d_in, d_out, heads, dropout=p_drop)
        self.ff = FeedForwardLayer(d_out, d_ff, p_drop=p_drop)
        # dropouts
        self.drop_1 = nn.Dropout(p_drop, inplace=True)
        self.drop_2 = nn.Dropout(p_drop, inplace=True)
        # layer norms
        self.norm = LayerNorm(d_in)
        self.norm1 = LayerNorm(d_out)
        self.norm2 = LayerNorm(d_out)

    def forward(self, src, tgt):
        """Pair to msa: src=pair (B, L, L, C), tgt=msa (B, N, L, K)."""
        B, N, L = tgt.shape[:3]

        # attention map source, optionally averaged with its (1, 2) transpose
        pair = 0.5 * (src + src.permute(0, 2, 1, 3)) if self.symmetrize else src
        pair = self.norm(pair)
        update = self.attn(pair, self.norm1(tgt))
        tgt = tgt + self.drop_1(update)

        # feed-forward; the norm is applied on a (B*N, L, -1) view
        normed = self.norm2(tgt.view(B * N, L, -1)).view(B, N, L, -1)
        tgt = tgt + self.drop_2(self.ff(normed))
        return tgt
class CrossEncoder(nn.Module):
    """Stack of cloned cross layers that repeatedly refine ``tgt`` from ``src``."""

    def __init__(self, enc_layer, n_layer):
        super().__init__()
        self.layers = _get_clones(enc_layer, n_layer)
        self.n_layer = n_layer

    def forward(self, src, tgt):
        """Apply every layer as ``layer(src, current_tgt)`` and return the result."""
        refined = tgt
        for block in self.layers:
            refined = block(src, refined)
        return refined

View file

@ -0,0 +1,163 @@
from utils.utils_profiling import * # load before other local modules
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Dict, List, Tuple
class Fiber(object):
    """Description of a fiber: a list of (multiplicity, degree) pairs.

    A degree ``d`` with multiplicity ``m`` occupies ``m * (2*d + 1)`` features.
    """

    def __init__(self, num_degrees: int=None, num_channels: int=None,
                 structure: List[Tuple[int,int]]=None, dictionary=None):
        """
        Build a fiber from exactly one of three descriptions.

        :param num_degrees: degrees will be [0, ..., num_degrees-1]
        :param num_channels: number of channels, same for each degree
        :param structure: e.g. [(32, 0),(16, 1),(16,2)]
        :param dictionary: e.g. {0:32, 1:16, 2:16}
        """
        if structure:
            pairs = structure
        elif dictionary:
            pairs = [(dictionary[deg], deg) for deg in sorted(dictionary.keys())]
        else:
            pairs = [(num_channels, deg) for deg in range(num_degrees)]
        self.structure = pairs

        self.multiplicities, self.degrees = zip(*pairs)
        self.max_degree = max(self.degrees)
        self.min_degree = min(self.degrees)
        # degree -> multiplicity; ``dict`` is an alias of the same object
        self.structure_dict = {deg: mult for mult, deg in pairs}
        self.dict = self.structure_dict
        self.n_features = np.sum([mult * (2 * deg + 1) for mult, deg in pairs])

        # degree -> (start, stop) slice bounds in the flattened feature axis
        self.feature_indices = {}
        start = 0
        for mult, deg in pairs:
            stop = start + mult * (2 * deg + 1)
            self.feature_indices[deg] = (start, stop)
            start = stop

    def copy_me(self, multiplicity: int=None):
        """Return a new Fiber with the same degrees; optionally override all multiplicities."""
        if multiplicity is None:
            return Fiber(structure=copy.deepcopy(self.structure))
        return Fiber(structure=[(multiplicity, deg) for _, deg in self.structure])

    @staticmethod
    def combine(f1, f2):
        """Union of degrees; multiplicities of shared degrees are added."""
        merged = copy.deepcopy(f1.structure_dict)
        for deg, mult in f2.structure_dict.items():
            if deg in merged:
                merged[deg] += mult
            else:
                merged[deg] = mult
        return Fiber(structure=[(merged[deg], deg) for deg in sorted(merged)])

    @staticmethod
    def combine_max(f1, f2):
        """Union of degrees; shared degrees take the larger multiplicity."""
        merged = copy.deepcopy(f1.structure_dict)
        for deg, mult in f2.structure_dict.items():
            merged[deg] = max(mult, merged[deg]) if deg in merged else mult
        return Fiber(structure=[(merged[deg], deg) for deg in sorted(merged)])

    @staticmethod
    def combine_selectively(f1, f2):
        """Like ``combine`` but restricted to the degrees that occur in ``f1``."""
        merged = copy.deepcopy(f1.structure_dict)
        for deg in f1.degrees:
            if deg in f2.degrees:
                merged[deg] += f2.structure_dict[deg]
        return Fiber(structure=[(merged[deg], deg) for deg in sorted(merged)])

    @staticmethod
    def combine_fibers(val1, struc1, val2, struc2):
        """
        Concatenate two fiber tensors degree by degree.

        :param val1/2: fiber tensors in dictionary form
        :param struc1/2: structure of fiber
        :return: fiber tensor in dictionary form
        """
        struc_out = Fiber.combine(struc1, struc2)
        val_out = {}
        for deg in struc_out.degrees:
            in_first = deg in struc1.degrees
            in_second = deg in struc2.degrees
            if in_first and in_second:
                # stack along the multiplicity axis (second to last)
                val_out[deg] = torch.cat([val1[deg], val2[deg]], -2)
            elif in_first:
                val_out[deg] = val1[deg]
            else:
                val_out[deg] = val2[deg]
            assert val_out[deg].shape[-2] == struc_out.structure_dict[deg]
        return val_out

    def __repr__(self):
        return f"{self.structure}"
def get_fiber_dict(F, struc, mask=None, return_struc=False):
    """Split a flat feature tensor into a {degree: tensor} dictionary.

    Args:
        F: tensor whose last axis concatenates all degrees of ``struc``.
        struc: object with ``structure_dict`` (degree -> multiplicity).
        mask: object with ``degrees``; only these degrees are kept.
            Defaults to ``struc`` (keep everything).
        return_struc: also return a ``Fiber`` describing the kept degrees.

    Returns:
        dict mapping degree d to a view of shape
        ``F.shape[:-1] + (multiplicity, 2*d + 1)``; optionally the Fiber.
    """
    selector = struc if mask is None else mask
    leading = list(F.shape[:-1])
    fiber_dict = {}
    kept = {}
    offset = 0
    for degree, mult in struc.structure_dict.items():
        width = mult * (2 * degree + 1)
        if degree in selector.degrees:
            kept[degree] = mult
            chunk = F[..., offset:offset + width]
            fiber_dict[degree] = chunk.view(leading + [mult, 2 * degree + 1])
        # skipped degrees still consume their slot in the feature axis
        offset += width
    assert F.shape[-1] == offset
    if return_struc:
        return fiber_dict, Fiber(dictionary=kept)
    return fiber_dict
def get_fiber_tensor(F, struc):
    """Flatten a {degree: tensor} dictionary into one tensor.

    Each entry ``F[d]`` of shape ``(..., m, 2*d + 1)`` is flattened to
    ``(..., m * (2*d + 1))`` and written side by side, in the iteration order
    of ``struc.structure_dict``, into a new tensor whose last axis has
    ``struc.n_features`` elements.
    """
    template = tuple(F.values())[0]
    leading = template.shape[:-2]
    flat = template.new_empty([*leading, struc.n_features])
    offset = 0
    for degree, mult in struc.structure_dict.items():
        width = mult * (2 * degree + 1)
        flat[..., offset:offset + width] = F[degree].view(*leading, width)
        offset += width
    assert offset == flat.shape[-1]
    return flat
def fiber2tensor(F, structure, squeeze=False):
    """Concatenate all degrees of a fiber dictionary keyed by ``str(degree)``.

    The last two axes of every entry are flattened into one. With
    ``squeeze=True`` the pieces are joined on that axis; otherwise a trailing
    axis of size 1 is kept and the pieces are joined on the axis before it.
    """
    pieces = []
    for degree in structure.degrees:
        entry = F[f'{degree}']
        leading = entry.shape[:-2]
        if squeeze:
            pieces.append(entry.view(*leading, -1))
        else:
            pieces.append(entry.view(*leading, -1, 1))
    return torch.cat(pieces, -1 if squeeze else -2)
def fiber2head(F, h, structure, squeeze=False):
    """Concatenate all degrees of a fiber dictionary, split into ``h`` heads.

    Like ``fiber2tensor`` but the last two axes of every entry are reshaped to
    ``(h, -1)`` (or ``(h, -1, 1)`` when ``squeeze`` is False) before joining.
    """
    pieces = []
    for degree in structure.degrees:
        entry = F[f'{degree}']
        leading = entry.shape[:-2]
        if squeeze:
            pieces.append(entry.view(*leading, h, -1))
        else:
            pieces.append(entry.view(*leading, h, -1, 1))
    return torch.cat(pieces, -1 if squeeze else -2)

View file

@ -0,0 +1,289 @@
# pylint: disable=C,E1101,E1102
'''
Some functions related to SO3 and its usual representations
Using ZYZ Euler angles parametrisation
'''
import torch
import math
import numpy as np
class torch_default_dtype:
    """Context manager that temporarily switches torch's default dtype.

    The previous default is recorded on entry and restored on exit, whether or
    not the body raised.
    """

    def __init__(self, dtype):
        self.dtype = dtype
        self.saved_dtype = None

    def __enter__(self):
        previous = torch.get_default_dtype()
        self.saved_dtype = previous
        torch.set_default_dtype(self.dtype)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_default_dtype(self.saved_dtype)
def rot_z(gamma):
    '''
    3x3 matrix of the rotation by angle gamma around the Z axis.

    A non-tensor angle is converted using the current default dtype; the
    result has the angle's dtype.
    '''
    angle = gamma if torch.is_tensor(gamma) else torch.tensor(gamma, dtype=torch.get_default_dtype())
    cos, sin = torch.cos(angle), torch.sin(angle)
    rows = [
        [cos, -sin, 0],
        [sin, cos, 0],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=angle.dtype)
def rot_y(beta):
    '''
    3x3 matrix of the rotation by angle beta around the Y axis.

    A non-tensor angle is converted using the current default dtype; the
    result has the angle's dtype.
    '''
    angle = beta if torch.is_tensor(beta) else torch.tensor(beta, dtype=torch.get_default_dtype())
    cos, sin = torch.cos(angle), torch.sin(angle)
    rows = [
        [cos, 0, sin],
        [0, 1, 0],
        [-sin, 0, cos],
    ]
    return torch.tensor(rows, dtype=angle.dtype)
def rot(alpha, beta, gamma):
    '''
    Rotation matrix for ZYZ Euler angles: Z(alpha) @ Y(beta) @ Z(gamma).
    '''
    first_z = rot_z(alpha)
    middle_y = rot_y(beta)
    last_z = rot_z(gamma)
    return first_z @ middle_y @ last_z
def x_to_alpha_beta(x):
    '''
    Convert a point (x, y, z) into the angles (alpha, beta).

    The point is normalized first; beta = acos(z) and alpha = atan2(y, x).
    '''
    point = x if torch.is_tensor(x) else torch.tensor(x, dtype=torch.get_default_dtype())
    unit = point / torch.norm(point)
    alpha = torch.atan2(unit[1], unit[0])
    beta = torch.acos(unit[2])
    return (alpha, beta)
# These functions (x_to_alpha_beta and rot) satisfy that
# rot(*x_to_alpha_beta([x, y, z]), 0) @ np.array([[0], [0], [1]])
# is proportional to
# [x, y, z]
def irr_repr(order, alpha, beta, gamma, dtype=None):
    """
    Irreducible representation of SO3 of the given order.

    Thin wrapper around lie_learn's ``wigner_D_matrix``: the three angles are
    converted with ``np.array`` and the result is returned as a torch tensor
    of ``dtype`` (the current default dtype when ``dtype`` is None).
    - compatible with compose and spherical_harmonics
    """
    # imported lazily so the module loads without lie_learn installed
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
    # TODO (non-essential): try to do everything in torch
    angles = [np.array(angle) for angle in (alpha, beta, gamma)]
    matrix = wigner_D_matrix(order, *angles)
    target_dtype = torch.get_default_dtype() if dtype is None else dtype
    return torch.tensor(matrix, dtype=target_dtype)
# def spherical_harmonics(order, alpha, beta, dtype=None):
# """
# spherical harmonics
# - compatible with irr_repr and compose
# """
# # from from_lielearn_SO3.spherical_harmonics import sh
# from lie_learn.representations.SO3.spherical_harmonics import sh # real valued by default
#
# ###################################################################################################################
# # ON ANGLE CONVENTION
# #
# # sh has following convention for angles:
# # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
# # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
# #
# # this function therefore (probably) has the following convention for alpha and beta:
# # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)).
# # alpha = phi
# #
# ###################################################################################################################
#
# Y = torch.tensor([sh(order, m, theta=math.pi - beta, phi=alpha) for m in range(-order, order + 1)], dtype=torch.get_default_dtype() if dtype is None else dtype)
# # if order == 1:
# # # change of basis to have vector_field[x, y, z] = [vx, vy, vz]
# # A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
# # return A @ Y
# return Y
def compose(a1, b1, c1, a2, b2, c2):
    """
    (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)

    The two rotations are multiplied as matrices. (a, b) are read off the
    image of the unit Z vector; c is the remaining rotation around Z after
    undoing Z(a) Y(b).
    """
    product = rot(a1, b1, c1) @ rot(a2, b2, c2)
    z_image = product @ torch.tensor([0, 0, 1.])
    a, b = x_to_alpha_beta(z_image)
    residual = rot(0, -b, -a) @ product
    c = torch.atan2(residual[1, 0], residual[0, 0])
    return a, b, c
def kron(x, y):
    """Kronecker product of two 2-D tensors.

    For x of shape (a, b) and y of shape (c, d) the result has shape
    (a*c, b*d) with entry [(i, k), (j, l)] equal to x[i, j] * y[k, l].
    """
    assert x.ndimension() == 2
    assert y.ndimension() == 2
    n_rows = x.size(0) * y.size(0)
    n_cols = x.size(1) * y.size(1)
    outer = torch.einsum("ij,kl->ikjl", x, y)
    return outer.view(n_rows, n_cols)
################################################################################
# Change of basis
################################################################################
def xyz_vector_basis_to_spherical_basis():
    """
    Change-of-basis matrix A converting a vector [x, y, z] transforming with
    rot(a, b, c) into a vector transforming with irr_repr(1, a, b, c).

    The defining identity ``irr_repr(1, a, b, c) @ A == A @ rot(a, b, c)`` is
    asserted on 10 random angle triples in float64 before A is returned in
    the current default dtype.
    """
    with torch_default_dtype(torch.float64):
        A = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        for a, b, c in torch.rand(10, 3):
            assert torch.allclose(irr_repr(1, a, b, c) @ A, A @ rot(a, b, c))
    return A.type(torch.get_default_dtype())
def tensor3x3_repr(a, b, c):
    """
    Representation acting on flattened 3x3 tensors,
    T --> R T R^t, built as the Kronecker product of R = rot(a, b, c) with itself.
    """
    rotation = rot(a, b, c)
    return kron(rotation, rotation)
def tensor3x3_repr_basis_to_spherical_basis():
    """
    to convert a 3x3 tensor transforming with tensor3x3_repr(a, b, c)
    into its 1 + 3 + 5 component transforming with irr_repr(0, a, b, c), irr_repr(1, a, b, c), irr_repr(2, a, b, c)
    see assert for usage

    Returns the three change-of-basis matrices (to1, to3, to5) of shapes
    (1, 9), (3, 9) and (5, 9). Each is checked against 10 random angle
    triples in float64 before being returned.
    """
    # NOTE(review): source indentation was lost; the return is assumed to sit
    # outside the float64 context so the result uses the caller's default dtype.
    with torch_default_dtype(torch.float64):
        # order 0: the trace of the 3x3 tensor
        to1 = torch.tensor([
            [1, 0, 0, 0, 1, 0, 0, 0, 1],
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(0, a, b, c) @ to1, to1 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
        # order 1: the antisymmetric part
        to3 = torch.tensor([
            [0, 0, -1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, -1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, -1, 0],
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(1, a, b, c) @ to3, to3 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
        # order 2: the symmetric traceless part
        to5 = torch.tensor([
            [0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0],
            [-3**.5/3, 0, 0, 0, -3**.5/3, 0, 0, 0, 12**.5/3],
            [0, 0, 1, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, -1, 0, 0, 0, 0]
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(2, a, b, c) @ to5, to5 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))
    return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype())
################################################################################
# Tests
################################################################################
def test_is_representation(rep):
    """
    Check the homomorphism property of ``rep`` on one random pair of rotations:
    rep(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = rep(Z(a1) Y(b1) Z(c1)) rep(Z(a2) Y(b2) Z(c2))

    Prints the max absolute deviation and the max absolute entry, and asserts
    the deviation is below 1e-10 relative to that entry (in float64).
    """
    with torch_default_dtype(torch.float64):
        a1, b1, c1, a2, b2, c2 = torch.rand(6)

        first = rep(a1, b1, c1)
        second = rep(a2, b2, c2)
        a, b, c = compose(a1, b1, c1, a2, b2, c2)
        of_product = rep(a, b, c)
        product_of = first @ second

        d = (of_product - product_of).abs().max()
        r = of_product.abs().max()
        print(d.item(), r.item())
        assert d < 1e-10 * r, d / r
def _test_spherical_harmonics(order):
    """
    This test tests that
    - irr_repr
    - compose
    - spherical_harmonics
    are compatible
    Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
    with x = Z(a) Y(b) eta

    The module-level ``spherical_harmonics`` exists only as commented-out
    code, so calling it raised NameError. It is restored here as a local
    helper that mirrors that commented-out reference implementation.
    """
    # imported lazily, like the other lie_learn uses in this module
    from lie_learn.representations.SO3.spherical_harmonics import sh  # real valued by default

    def spherical_harmonics(l, alpha, beta):
        # Angle convention taken from the commented-out reference:
        # theta = pi - beta, phi = alpha.
        # NOTE(review): that reference itself marks the convention as
        # "probably" correct - confirm if this check fails numerically.
        theta = math.pi - float(beta)
        phi = float(alpha)
        values = [sh(l, m, theta=theta, phi=phi) for m in range(-l, l + 1)]
        return torch.tensor(values, dtype=torch.get_default_dtype())

    with torch_default_dtype(torch.float64):
        a, b = torch.rand(2)
        alpha, beta, gamma = torch.rand(3)
        ra, rb, _ = compose(alpha, beta, gamma, a, b, 0)
        Yrx = spherical_harmonics(order, ra, rb)
        Y = spherical_harmonics(order, a, b)
        DrY = irr_repr(order, alpha, beta, gamma) @ Y
        d, r = (Yrx - DrY).abs().max(), Y.abs().max()
        print(d.item(), r.item())
        assert d < 1e-10 * r, d / r
def _test_change_basis_wigner_to_rot():
    """Check that A^T D^1(a, b, c) A equals rot(a, b, c) for one random triple.

    A is the permutation matrix relating the order-1 Wigner D basis to the
    (x, y, z) basis. Prints the max deviation and asserts it is below 1e-10.
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
    with torch_default_dtype(torch.float64):
        A = torch.tensor([
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0]
        ], dtype=torch.float64)
        a, b, c = torch.rand(3)

        wigner = torch.tensor(wigner_D_matrix(1, a, b, c), dtype=torch.float64)
        from_wigner = A.t() @ wigner @ A
        direct = rot(a, b, c)

        d = (from_wigner - direct).abs().max()
        print(d.item())
        assert d < 1e-10
if __name__ == "__main__":
    # Self-check driver: runs the randomized consistency checks of this module.
    from functools import partial

    print("Change of basis")
    xyz_vector_basis_to_spherical_basis()
    test_is_representation(tensor3x3_repr)
    tensor3x3_repr_basis_to_spherical_basis()

    print("Change of basis Wigner <-> rot")
    for _ in range(3):  # three independent random draws
        _test_change_basis_wigner_to_rot()

    print("Spherical harmonics are solution of Y(rx) = D(r) Y(x)")
    for order in range(7):
        _test_spherical_harmonics(order)

    print("Irreducible repr are indeed representations")
    for order in range(7):
        test_is_representation(partial(irr_repr, order))

View file

@ -0,0 +1,112 @@
'''
Cache in files
'''
from functools import wraps, lru_cache
import pickle
import gzip
import os
import sys
import fcntl
class FileSystemMutex:
    '''
    Mutual exclusion between different **processes**, implemented with an
    exclusive ``fcntl.lockf`` lock on a file. Usable as a context manager.
    '''

    def __init__(self, filename):
        self.filename = filename
        self.handle = None  # open file object while the lock is held

    def acquire(self):
        '''
        Take the lock, blocking until it is available.
        The lock file is opened for writing and stamped with this process id.
        '''
        lock_file = open(self.filename, 'w')
        self.handle = lock_file
        fcntl.lockf(lock_file, fcntl.LOCK_EX)
        lock_file.write("{}\n".format(os.getpid()))
        lock_file.flush()

    def release(self):
        '''
        Drop the lock and close the file.
        Raises RuntimeError when the mutex is not currently held.
        '''
        lock_file = self.handle
        if lock_file is None:
            raise RuntimeError()
        fcntl.lockf(lock_file, fcntl.LOCK_UN)
        lock_file.close()
        self.handle = None

    def __enter__(self):
        self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
def cached_dirpklgz(dirname, maxsize=128):
    '''
    Cache a function's results in a directory of gzip-compressed pickles,
    fronted by an in-memory LRU cache.

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    :return: a decorator; the decorated function needs hashable arguments

    ``<dirname>/index.pkl`` maps each call signature to a ``<n>.pkl.gz`` file
    holding the result. Access to the index and result files is serialized
    across processes with a ``FileSystemMutex``.
    '''

    def decorator(func):
        '''
        The actual decorator
        '''

        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            '''
            The wrapper of the function
            '''
            os.makedirs(dirname, exist_ok=True)

            indexfile = os.path.join(dirname, "index.pkl")
            # The lock file is deliberately relative to the current working
            # directory, not placed inside dirname.
            mutexfile = "cache_mutex"
            #mutexfile = os.path.join(dirname, "mutex")

            with FileSystemMutex(mutexfile):
                # NOTE(review): pickle.load executes arbitrary code from the
                # file; the cache directory must not be writable by untrusted
                # parties.
                try:
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)
                except FileNotFoundError:
                    index = {}

                # Keyword VALUES must be part of the key. frozenset(kwargs)
                # would keep only the names, so f(x=1) and f(x=2) would share
                # one cache file.
                key = (args, frozenset(kwargs.items()), func.__defaults__)

                try:
                    filename = index[key]
                except KeyError:
                    index[key] = filename = "{}.pkl.gz".format(len(index))
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = os.path.join(dirname, filename)

            try:
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
            except FileNotFoundError:
                # Not computed yet (or the file was removed): compute outside
                # the lock, then store under the lock.
                print("compute {}... ".format(filename), end="")
                sys.stdout.flush()
                result = func(*args, **kwargs)
                print("save {}... ".format(filename), end="")
                sys.stdout.flush()
                with FileSystemMutex(mutexfile):
                    with gzip.open(filepath, "wb") as file:
                        pickle.dump(result, file)
                print("done")
            return result
        return wrapper
    return decorator

View file

@ -0,0 +1,25 @@
the code in this folder was mostly obtained from https://github.com/mariogeiger/se3cnn/
which has the following license:
MIT License
Copyright (c) 2019 Mario Geiger
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,249 @@
import time
import torch
import numpy as np
from scipy.special import lpmv as lpmv_scipy
def semifactorial(x):
    """Compute the semifactorial (double factorial) x!!.

    x!! = x * (x-2) * (x-4) *... down to the last factor greater than 1.

    Args:
        x: int; values <= 1 give the empty product
    Returns:
        float for x!!
    """
    product = 1.
    factor = x
    while factor > 1:
        product *= factor
        factor -= 2
    return product
def pochhammer(x, k):
    """Compute the pochhammer symbol (rising factorial) (x)_k.

    (x)_k = x * (x+1) * (x+2) *...* (x+k-1)

    Args:
        x: positive int, the first factor
        k: non-negative int, the number of factors
    Returns:
        float for (x)_k; (x)_0 is the empty product 1.0
    """
    result = 1.
    # Starting from 1.0 and looping over all k factors makes k == 0 yield the
    # empty product. Seeding the product with x would return x for k == 0.
    for n in range(x, x + k):
        result *= n
    return result
def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Starts from P_m^m and climbs in degree with the three-term recurrence.

    Args:
        m: int order
        l: int degree
        x: float argument tensor
    Returns:
        tensor of x-shape (all zeros when |m| > l)
    """
    m_abs = abs(m)
    if m_abs > l:
        return torch.zeros_like(x)

    # P_m^m
    prev = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2)
    # P_{m+1}^m, unless the target degree is already reached
    curr = prev if m_abs == l else x * (2*m_abs+1) * prev

    # P_l^m from P_{l-1}^m and P_{l-2}^m
    for degree in range(m_abs+2, l+1):
        older = curr
        curr = ((2*degree-1) / (degree-m_abs)) * x * curr
        curr -= ((degree+m_abs-1)/(degree-m_abs)) * prev  # in place on the fresh tensor
        prev = older

    if m < 0:
        # convert the |m| result to negative order
        curr *= ((-1)**m / pochhammer(l+m+1, -2*m))
    return curr
def tesseral_harmonics(l, m, theta=0., phi=0.):
    """Tesseral spherical harmonic with Condon-Shortley phase.

    The Tesseral spherical harmonics are also known as the real spherical
    harmonics.

    Args:
        l: int for degree
        m: int for order, where -l <= m <= l
        theta: colatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape theta
    """
    assert abs(m) <= l, "absolute value of order m must be <= degree l"

    order = abs(m)
    N = np.sqrt((2*l+1) / (4*np.pi))
    leg = lpmv(l, order, torch.cos(theta))
    if m == 0:
        return N*leg

    # cosine branch for positive orders, sine branch for negative ones
    azimuthal = torch.cos(m*phi) if m > 0 else torch.sin(order*phi)
    Y = azimuthal * leg
    N *= np.sqrt(2. / pochhammer(l-order+1, 2*order))
    Y *= N
    return Y
class SphericalHarmonics(object):
    """Real (tesseral) spherical harmonics with memoized Legendre functions.

    Legendre values are cached in ``self.leg`` under the key ``(l, m)`` only.
    The argument tensor is NOT part of the key, so the cache must be cleared
    (``clear()`` or ``get(..., refresh=True)``) before evaluating at different
    angles.
    """
    def __init__(self):
        # memo table: (l, m) -> associated Legendre tensor
        self.leg = {}
    def clear(self):
        """Forget all memoized Legendre values."""
        self.leg = {}
    def negative_lpmv(self, l, m, y):
        """Compute negative order coefficients

        For m < 0, rescales ``y`` in place; for m >= 0 returns ``y`` unchanged.
        """
        if m < 0:
            y *= ((-1)**m / pochhammer(l+m+1, -2*m))
        return y
    def lpmv(self, l, m, x):
        """Associated Legendre function including Condon-Shortley phase.
        Args:
            m: int order
            l: int degree
            x: float argument tensor
        Returns:
            tensor of x-shape, or None when |m| > l
        """
        # Check memoized versions
        m_abs = abs(m)
        if (l,m) in self.leg:
            return self.leg[(l,m)]
        elif m_abs > l:
            return None
        elif l == 0:
            self.leg[(l,m)] = torch.ones_like(x)
            return self.leg[(l,m)]
        # Check if on boundary else recurse solution down to boundary
        if m_abs == l:
            # Compute P_m^m
            y = (-1)**m_abs * semifactorial(2*m_abs-1)
            y *= torch.pow(1-x*x, m_abs/2)
            self.leg[(l,m)] = self.negative_lpmv(l, m, y)
            return self.leg[(l,m)]
        else:
            # Recursively precompute lower degree harmonics
            self.lpmv(l-1, m, x)
        # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
        # Inplace speedup
        y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x)
        if l - m_abs > 1:
            # (l-2, m_abs) is read straight from the memo table; it is filled
            # by the self.lpmv(l-1, m_abs, x) call above when that call
            # actually recurses (i.e. was not itself served from the cache).
            y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)]
        #self.leg[(l, m_abs)] = y
        if m < 0:
            y = self.negative_lpmv(l, m, y)
        self.leg[(l,m)] = y
        return self.leg[(l,m)]
    def get_element(self, l, m, theta, phi):
        """Tesseral spherical harmonic with Condon-Shortley phase.
        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.
        Does not clear the memo table; see the class docstring.
        Args:
            l: int for degree
            m: int for order, where -l <= m <= l
            theta: colatitude or polar angle
            phi: longitude or azimuth
        Returns:
            tensor of shape theta
        """
        assert abs(m) <= l, "absolute value of order m must be <= degree l"
        N = np.sqrt((2*l+1) / (4*np.pi))
        leg = self.lpmv(l, abs(m), torch.cos(theta))
        if m == 0:
            return N*leg
        elif m > 0:
            Y = torch.cos(m*phi) * leg
        else:
            Y = torch.sin(abs(m)*phi) * leg
        N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m)))
        Y *= N
        return Y
    def get(self, l, theta, phi, refresh=True):
        """Tesseral harmonic with Condon-Shortley phase.
        The Tesseral spherical harmonics are also known as the real spherical
        harmonics.
        Args:
            l: int for degree
            theta: colatitude or polar angle
            phi: longitude or azimuth
            refresh: when True (default) the memo table is cleared first
        Returns:
            tensor of shape [*theta.shape, 2*l+1]
        """
        results = []
        if refresh:
            self.clear()
        # orders are stacked from m = -l to m = +l along the last axis
        for m in range(-l, l+1):
            results.append(self.get_element(l, m, theta, phi))
        return torch.stack(results, -1)
if __name__ == "__main__":
    # Ad-hoc accuracy / timing comparison between SphericalHarmonics.get_element
    # (evaluated on `device`) and lie_learn's reference `sh` (evaluated on the
    # CPU tensors). Requires the third-party `lie_learn` package.
    # NOTE(review): `time` is used below but its import is not visible in this
    # part of the file - confirm it is imported at module level.
    from lie_learn.representations.SO3.spherical_harmonics import sh
    device = 'cuda'
    dtype = torch.float64
    bs = 32
    # Small random angles; CPU copies feed the reference, device copies feed ours.
    theta = 0.1*torch.randn(bs,1024,10, dtype=dtype)
    phi = 0.1*torch.randn(bs,1024,10, dtype=dtype)
    cu_theta = theta.to(device)
    cu_phi = phi.to(device)
    # s0: accumulated time of this implementation, s1: of the reference
    # (s2 is only used by the commented-out timing of `get`).
    s0 = s1 = s2 = 0
    max_error = -1.
    sph_har = SphericalHarmonics()
    for l in range(10):
        for m in range(l, -l-1, -1):
            start = time.time()
            #y = tesseral_harmonics(l, m, theta, phi)
            y = sph_har.get_element(l, m, cu_theta, cu_phi).type(torch.float32)
            #y = sph_har.lpmv(l, m, phi)
            s0 += time.time() - start
            start = time.time()
            z = sh(l, m, theta, phi)
            #z = lpmv_scipy(m, l, phi).numpy()
            s1 += time.time() - start
            # Mean relative deviation from the reference values.
            error = np.mean(np.abs((y.cpu().numpy() - z) / z))
            max_error = max(max_error, error)
            print(f"l: {l}, m: {m} ", error)
        #start = time.time()
        #sph_har.get(l, theta, phi)
        #s2 += time.time() - start
    print('#################')
    print(f"Max error: {max_error}")
    print(f"Time diff: {s0/s1}")
    print(f"Total time: {s0}")
    #print(f"Time diff: {s2/s1}")

View file

@ -0,0 +1,326 @@
import os
import torch
import math
import numpy as np
from equivariant_attention.from_se3cnn.SO3 import irr_repr, torch_default_dtype
from equivariant_attention.from_se3cnn.cache_file import cached_dirpklgz
from equivariant_attention.from_se3cnn.representations import SphericalHarmonics
################################################################################
# Solving the constraint coming from the stabilizer of 0 and e
################################################################################
def get_matrix_kernel(A, eps=1e-10):
    '''
    Orthonormal basis (x_1, x_2, ...) of the kernel of A, obtained from its SVD:
        A x_i = 0
        scalar_product(x_i, x_j) = delta_ij

    :param A: matrix
    :param eps: singular values below this threshold are treated as zero
    :return: matrix where each row is a basis vector of the kernel of A
    '''
    # A = u @ torch.diag(s) @ v.t(); only s and v are needed here
    _, singular_values, right_vectors = torch.svd(A)
    is_null = singular_values < eps
    return right_vectors.t()[is_null]
def get_matrices_kernel(As, eps=1e-10):
    '''
    Compute the common kernel of all the matrices in As.

    The matrices are stacked along dim 0 and the kernel of the stacked matrix
    is returned, one basis vector per row.
    '''
    stacked = torch.cat(As, dim=0)
    return get_matrix_kernel(stacked, eps)
@cached_dirpklgz("%s/cache/trans_Q"%os.path.dirname(os.path.realpath(__file__)))
def _basis_transformation_Q_J(J, order_in, order_out, version=3): # pylint: disable=W0613
    """
    Solve for the change-of-basis block Q_J as the common null space of
    Sylvester-type constraints built from a fixed set of rotation angles.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: not read by the body (hence the pylint waiver).
        NOTE(review): presumably only distinguishes cached results - confirm
        against cached_dirpklgz.
    :return: one part of the Q^-1 matrix of the article, a float64 tensor of
        shape [(2*order_out+1) * (2*order_in+1), 2*J+1]
    """
    with torch_default_dtype(torch.float64):
        # Tensor-product representation of the rotation (a, b, c).
        def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
            R_tensor = _R_tensor(a, b, c) # [m_out * m_in, m_out * m_in]
            R_irrep_J = irr_repr(J, a, b, c) # [m, m]
            return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
                kron(torch.eye(R_tensor.size(0)), R_irrep_J.t()) # [(m_out * m_in) * m, (m_out * m_in) * m]
        # Hard-coded angle triples: the constraint is imposed for each of them
        # and the common kernel is taken, which keeps the result deterministic.
        random_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
        assert null_space.size(0) == 1, null_space.size() # unique subspace solution
        Q_J = null_space[0] # [(m_out * m_in) * m]
        Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1) # [m_out * m_in, m]
        # Sanity check of the intertwining property on four fresh random angle triples.
        assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))
    assert Q_J.dtype == torch.float64
    return Q_J # [m_out * m_in, m]
def get_spherical_from_cartesian_torch(cartesian, divide_radius_by=1.0):
    """Convert Cartesian vectors (last axis of size 3) to (radius, alpha, beta).

    ON ANGLE CONVENTION

    sh has following convention for angles:
      theta: the colatitude / polar angle, ranging from 0 (North Pole,
             (X, Y, Z) = (0, 0, 1)) to pi (South Pole, (X, Y, Z) = (0, 0, -1)).
      phi:   the longitude / azimuthal angle, ranging from 0 to 2 pi.

    the 3D steerable CNN code therefore (probably) has the following
    convention for alpha and beta:
      beta = pi - theta; ranging from 0 (South Pole) to pi (North Pole).
      alpha = phi

    :param cartesian: tensor whose last axis stores (y, z, x), i.e. index 2 is
        read as x, index 0 as y and index 1 as z
    :param divide_radius_by: the radius is divided by this value unless it is 1.0
    :return: tensor of the same shape; last axis is (radius, alpha, beta)
    """
    spherical = torch.zeros_like(cartesian)

    x = cartesian[..., 2]
    y = cartesian[..., 0]
    z = cartesian[..., 1]

    # squared length of the projection onto the x-y plane
    planar_sq = x ** 2 + y ** 2

    # beta: elevation angle defined from the z-axis down
    spherical[..., 2] = torch.atan2(torch.sqrt(planar_sq), z)
    # alpha: angle in the x-y plane
    spherical[..., 1] = torch.atan2(y, x)

    # overall radius, optionally rescaled
    radius = torch.sqrt(planar_sq + z**2)
    if divide_radius_by != 1.0:
        radius = radius/divide_radius_by
    spherical[..., 0] = radius
    return spherical
def get_spherical_from_cartesian(cartesian):
    """NumPy conversion of Cartesian vectors (last axis of size 3) to (radius, alpha, beta).

    ON ANGLE CONVENTION

    sh has following convention for angles:
      theta: the colatitude / polar angle, ranging from 0 (North Pole,
             (X, Y, Z) = (0, 0, 1)) to pi (South Pole, (X, Y, Z) = (0, 0, -1)).
      phi:   the longitude / azimuthal angle, ranging from 0 to 2 pi.

    the 3D steerable CNN code therefore (probably) has the following
    convention for alpha and beta:
      beta = pi - theta; ranging from 0 (South Pole) to pi (North Pole).
      alpha = phi

    :param cartesian: numpy array or torch tensor (moved to CPU and converted);
        the last axis stores (y, z, x), i.e. index 2 is read as x, index 0 as
        y and index 1 as z
    :return: numpy array of the same shape; last axis is (radius, alpha, beta)
    """
    if torch.is_tensor(cartesian):
        cartesian = np.array(cartesian.cpu())

    spherical = np.zeros(cartesian.shape)

    x = cartesian[..., 2]
    y = cartesian[..., 0]
    z = cartesian[..., 1]

    # squared length of the projection onto the x-y plane
    planar_sq = x ** 2 + y ** 2

    # overall radius
    spherical[..., 0] = np.sqrt(planar_sq + z**2)
    # beta: elevation angle defined from the z-axis down
    spherical[..., 2] = np.arctan2(np.sqrt(planar_sq), z)
    # alpha: angle in the x-y plane
    spherical[..., 1] = np.arctan2(y, x)
    return spherical
def test_coordinate_conversion():
    """Regression check for get_spherical_from_cartesian on a single point.

    The previous assertion compared two arrays with ``==`` and asserted on the
    resulting 3-element boolean array, which raises ValueError (ambiguous truth
    value) instead of testing anything; ``np.allclose`` is used instead.

    :return: True when the conversion matches the expected value
    :raises AssertionError: when it does not
    """
    p = np.array([0, 0, -1])
    # get_spherical_from_cartesian reads x = p[2], y = p[0], z = p[1], so this
    # point is (x, y, z) = (-1, 0, 0): radius 1, alpha = atan2(0, -1) = pi and
    # beta = atan2(1, 0) = pi / 2. Output order is (radius, alpha, beta).
    expected = np.array([1, np.pi, np.pi / 2])
    assert np.allclose(get_spherical_from_cartesian(p), expected)
    return True
def spherical_harmonics(order, alpha, beta, dtype=None):
    """
    spherical harmonics
    - compatible with irr_repr and compose

    computation time: executing 1000 times with array length 1 took 0.29 seconds;
    executing it once with array of length 1000 took 0.0022 seconds

    :param order: degree of the harmonics to evaluate
    :param alpha: azimuthal angle, passed on as phi
    :param beta: polar angle in the beta = pi - theta convention
    :param dtype: accepted for interface compatibility; not used by this function
    :return: result of SphericalHarmonics.get; Y should have dimension 2*order + 1
        on its last axis
    """
    # `get` is an instance method (it relies on the per-instance cache), so it
    # has to be called on an instance. Calling it on the class bound `order`
    # to `self` and failed with a TypeError for the missing `l` argument.
    return SphericalHarmonics().get(order, theta=math.pi-beta, phi=alpha)
def kron(a, b):
    """
    Batched Kronecker product of the matrices held in the last two dims of a and b.

    A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk
    Leading batch dimensions are broadcast against each other.

    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    # Shape of each resulting matrix: element-wise product of the two matrix shapes.
    matrix_shape = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))

    # [..., m, 1, n, 1] * [..., 1, p, 1, q] -> [..., m, p, n, q]
    left = a.unsqueeze(-1).unsqueeze(-3)
    right = b.unsqueeze(-2).unsqueeze(-4)
    product = left * right

    batch_shape = product.shape[:-4]
    return product.reshape(batch_shape + matrix_shape)
def get_maximum_order_unary_only(per_layer_orders_and_multiplicities):
    """
    Determine the highest spherical-harmonics order to pre-compute when only
    the unary term is present, by comparing every pair of adjacent layers.

    the spherical harmonics function depends on J (irrep order) purely, which is defined by
    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
    simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
    order spherical harmonics which we won't actually need)

    :param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
        (multiplicity, order)
    :return: integer indicating maximum order J (0 when there are fewer than two layers)
    """
    # keep only the orders of every layer
    orders_by_layer = [[order for (_mult, order) in layer]
                       for layer in per_layer_orders_and_multiplicities]

    # the largest J for two adjacent layers is the sum of their largest orders
    best = 0
    for current, following in zip(orders_by_layer, orders_by_layer[1:]):
        best = max(max(current) + max(following), best)
    return best
def get_maximum_order_with_pairwise(per_layer_orders_and_multiplicities):
    """
    Determine the highest spherical-harmonics order to pre-compute when
    pairwise interactions are present: twice the maximum order of any layer.

    the spherical harmonics function depends on J (irrep order) purely, which is defined by
    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
    simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
    order spherical harmonics which we won't actually need)

    :param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
        (multiplicity, order)
    :return: integer indicating maximum order J
    """
    highest = 0
    for layer in per_layer_orders_and_multiplicities:
        # only the order of each (multiplicity, order) pair matters
        highest = max(highest, max(order for (_mult, order) in layer))
    return 2*highest
def precompute_sh(r_ij, max_J):
    """
    Pre-compute spherical harmonics for every order from 0 up to max_J.

    :param r_ij: relative positions; index 1 of the last axis is read as alpha
        and index 2 as beta
    :param max_J: maximum order used in entire network
    :return: dict where each entry has shape [B,N,K,2J+1]
    """
    harmonics = SphericalHarmonics()
    Y_Js = {}
    for J in range(max_J+1):
        # dimension [B,N,K,2J+1]
        polar = math.pi-r_ij[..., 2]
        azimuth = r_ij[..., 1]
        # refresh=False: the Legendre cache is shared across all orders
        Y_Js[J] = harmonics.get(J, theta=polar, phi=azimuth, refresh=False)
    harmonics.clear()
    return Y_Js
class ScalarActivation3rdDim(torch.nn.Module):
    """Activation with an optional learned bias for scalar fields [B, N, s]."""

    def __init__(self, n_dim, activation, bias=True):
        '''
        Can be used only with scalar fields [B, N, s] on last dimension

        :param n_dim: number of scalar fields to apply activation to
        :param activation: callable applied to the (optionally biased) input
        :param bool bias: add a bias before applying the activation
        '''
        super().__init__()
        self.activation = activation
        # a bias is only created when requested and there is something to bias
        wants_bias = bias and n_dim > 0
        self.bias = torch.nn.Parameter(torch.zeros(n_dim)) if wants_bias else None

    def forward(self, input):
        '''
        :param input: [B, N, s]
        :return: activation of the input, with the bias added first when present
        '''
        assert len(np.array(input.shape)) == 3
        shifted = input if self.bias is None else input + self.bias.view(1, 1, -1)
        return self.activation(shifted)

View file

@ -0,0 +1,905 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext
from typing import Dict
from equivariant_attention.from_se3cnn import utils_steerable
from equivariant_attention.fibers import Fiber, fiber2head
from utils.utils_logging import log_gradient_norm
import dgl
import dgl.function as fn
from dgl.nn.pytorch.softmax import edge_softmax
from dgl.nn.pytorch.glob import AvgPooling, MaxPooling
from packaging import version
def get_basis(G, max_degree, compute_gradients):
    """Precompute the SE(3)-equivariant weight basis, W_J^lk(x)

    This is called by get_basis_and_r().

    Args:
        G: DGL graph instance of type dgl.DGLGraph
        max_degree: non-negative int for degree of highest feature type
        compute_gradients: boolean, whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases. Keys are in the form 'd_in,d_out'. Values are
        tensors of shape (batch_size, 1, 2*d_out+1, 1, 2*d_in+1, number_of_bases)
        where the 1's will later be broadcast to the number of output and input
        channels
    """
    # Gradient tracking is switched off unless explicitly requested.
    grad_mode = nullcontext() if compute_gradients else torch.no_grad()

    with grad_mode:
        cloned_d = torch.clone(G.edata['d'])
        if G.edata['d'].requires_grad:
            cloned_d.requires_grad_()
            log_gradient_norm(cloned_d, 'Basis computation flow')

        # Relative positional encodings (vector)
        r_ij = utils_steerable.get_spherical_from_cartesian_torch(cloned_d)
        # Spherical harmonic basis
        Y = utils_steerable.precompute_sh(r_ij, 2*max_degree)
        device = Y[0].device

        basis = {}
        for d_in in range(max_degree+1):
            for d_out in range(max_degree+1):
                kernels = []
                for J in range(abs(d_in-d_out), d_in+d_out+1):
                    # Spherical harmonic projection matrix for this J, moved next to Y
                    projection = utils_steerable._basis_transformation_Q_J(J, d_in, d_out)
                    projection = projection.float().T.to(device)
                    # Kernel from spherical harmonics
                    kernels.append(torch.matmul(Y[J], projection))

                # Reshape so can take linear combinations with a dot product
                target_shape = (-1, 1, 2*d_out+1, 1, 2*d_in+1, 2*min(d_in, d_out)+1)
                basis[f'{d_in},{d_out}'] = torch.stack(kernels, -1).view(*target_shape)
        return basis
def get_r(G):
    """Compute internodal distances.

    Takes the Euclidean norm over the last axis of a clone of G.edata['d'],
    keeping that axis with size 1.
    """
    displacement = G.edata['d']
    cloned_d = torch.clone(displacement)

    if displacement.requires_grad:
        cloned_d.requires_grad_()
        log_gradient_norm(cloned_d, 'Neural networks flow')

    squared_length = torch.sum(cloned_d**2, -1, keepdim=True)
    return torch.sqrt(squared_length)
def get_basis_and_r(G, max_degree, compute_gradients=False):
    """Return equivariant weight basis (basis) and internodal distances (r).

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        G: DGL graph instance of type dgl.DGLGraph()
        max_degree: non-negative int for degree of highest feature-type
        compute_gradients: controls whether to compute gradients during basis construction
    Returns:
        dict of equivariant bases, as produced by get_basis (keyed 'd_in,d_out')
        vector of relative distances, ordered according to edge ordering of G
    """
    # The basis is built first, then the distances.
    equivariant_basis = get_basis(G, max_degree, compute_gradients)
    distances = get_r(G)
    return equivariant_basis, distances
### SE(3) equivariant operations on graphs in DGL
class GConvSE3(nn.Module):
    """A tensor field network layer as a DGL module.

    GConvSE3 stands for a Graph Convolution SE(3)-equivariant layer. It is the
    equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
    conv layer in a GCN.

    At each node, the activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..
    """
    def __init__(self, f_in, f_out, self_interaction: bool=False, edge_dim: int=0, flavor='skip'):
        """SE(3)-equivariant Graph Conv Layer

        Args:
            f_in: list of tuples [(multiplicities, type),...]
            f_out: list of tuples [(multiplicities, type),...]
            self_interaction: include self-interaction in convolution
            edge_dim: number of dimensions for edge embedding
            flavor: allows ['TFN', 'skip'], where 'skip' adds a skip connection
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction
        self.flavor = flavor

        # Neighbor -> center weights: one PairwiseConv per (input, output) degree pair
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                pair_key = f'({di},{do})'
                self.kernel_unary[pair_key] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

        # Center -> center weights (left empty without self-interaction)
        self.kernel_self = nn.ParameterDict()
        if not self_interaction:
            return

        assert self.flavor in ['TFN', 'skip']
        if self.flavor == 'TFN':
            # one square mixing matrix per output degree
            for m_out, d_out in self.f_out.structure:
                weight = nn.Parameter(torch.randn(1, m_out, m_out) / np.sqrt(m_out))
                self.kernel_self[f'{d_out}'] = weight
        elif self.flavor == 'skip':
            # one input -> output matrix per degree present on both sides
            for m_in, d_in in self.f_in.structure:
                if d_in not in self.f_out.degrees:
                    continue
                m_out = self.f_out.structure_dict[d_in]
                weight = nn.Parameter(torch.randn(1, m_out, m_in) / np.sqrt(m_in))
                self.kernel_self[f'{d_in}'] = weight

    def __repr__(self):
        return f'GConvSE3(structure={self.f_out}, self_interaction={self.self_interaction})'

    def udf_u_mul_e(self, d_out):
        """Compute the convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            edge -> node function handle
        """
        out_key = f'{d_out}'
        out_width = 2*d_out+1

        def fnc(edges):
            # Neighbor -> center messages, summed over all input degrees
            msg = 0
            for m_in, d_in in self.f_in.structure:
                src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, out_width)

            # Center -> center messages
            if self.self_interaction and out_key in self.kernel_self.keys():
                if self.flavor == 'TFN':
                    msg = torch.matmul(self.kernel_self[out_key], msg)
                if self.flavor == 'skip':
                    dst = edges.dst[out_key]
                    msg = msg + torch.matmul(self.kernel_self[out_key], dst)

            return {'msg': msg.view(msg.shape[0], -1, out_width)}
        return fnc

    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            G: minibatch of (homo)graphs
            h: dict of features
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            dict mapping each output degree (as a string) to the node features
            averaged over incoming messages
        """
        with G.local_scope():
            # Add node features to local graph scope
            for name, value in h.items():
                G.ndata[name] = value

            # Edge features: optional edge attributes 'w' followed by the distances
            if 'w' in G.edata.keys():
                feat = torch.cat([G.edata['w'], r], -1)
            else:
                feat = torch.cat([r, ], -1)

            # One kernel per (input degree, output degree) pair, stored on the edges
            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)

            # Perform message-passing for each output feature type
            for degree in self.f_out.degrees:
                G.update_all(self.udf_u_mul_e(degree), fn.mean('msg', f'out{degree}'))

            return {f'{degree}': G.ndata[f'out{degree}'] for degree in self.f_out.degrees}
class RadialFunc(nn.Module):
    """NN parameterized radial profile function."""
    def __init__(self, num_freq, in_dim, out_dim, edge_dim: int=0):
        """NN parameterized radial profile function.

        Args:
            num_freq: number of output frequencies
            in_dim: multiplicity of input (num input channels)
            out_dim: multiplicity of output (num output channels)
            edge_dim: number of dimensions for edge embedding
        """
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = 32
        self.out_dim = out_dim
        self.edge_dim = edge_dim

        # Three linear layers with normalisation + ReLU in between; the input
        # has edge_dim + 1 features.
        layers = [
            nn.Linear(self.edge_dim+1, self.mid_dim),
            BN(self.mid_dim),
            nn.ReLU(),
            nn.Linear(self.mid_dim, self.mid_dim),
            BN(self.mid_dim),
            nn.ReLU(),
            nn.Linear(self.mid_dim, self.num_freq*in_dim*out_dim),
        ]
        self.net = nn.Sequential(*layers)

        # Re-initialise the weights of the linear layers (positions 0, 3 and 6)
        for position in (0, 3, 6):
            nn.init.kaiming_uniform_(self.net[position].weight)

    def __repr__(self):
        return f"RadialFunc(edge_dim={self.edge_dim}, in_dim={self.in_dim}, out_dim={self.out_dim})"

    def forward(self, x):
        """Run the network and reshape to (-1, out_dim, 1, in_dim, 1, num_freq)."""
        radial = self.net(x)
        return radial.view(-1, self.out_dim, 1, self.in_dim, 1, self.num_freq)
class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features"""
    def __init__(self, degree_in: int, nc_in: int, degree_out: int,
                 nc_out: int, edge_dim: int=0):
        """SE(3)-equivariant convolution between a pair of feature types.

        This layer performs a convolution from nc_in features of type degree_in
        to nc_out features of type degree_out.

        Args:
            degree_in: degree of input fiber
            nc_in: number of channels on input
            degree_out: degree of out order
            nc_out: number of channels on output
            edge_dim: number of dimensions for edge embedding
        """
        super().__init__()
        # Log settings
        self.degree_in, self.degree_out = degree_in, degree_out
        self.nc_in, self.nc_out = nc_in, nc_out

        # Functions of the degree
        self.num_freq = 2*min(degree_in, degree_out) + 1
        self.d_out = 2*degree_out + 1
        self.edge_dim = edge_dim

        # Radial profile function
        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, self.edge_dim)

    def forward(self, feat, basis):
        """Build the kernel for this degree pair.

        Radial weights computed from ``feat`` are multiplied with the basis
        entry keyed 'degree_in,degree_out' and summed over the last axis.
        """
        # Get radial weights
        radial = self.rp(feat)
        pair_basis = basis[f'{self.degree_in},{self.degree_out}']
        kernel = torch.sum(radial * pair_basis, -1)
        return kernel.view(kernel.shape[0], self.d_out*self.nc_out, -1)
class G1x1SE3(nn.Module):
    """Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.

    This is equivalent to a self-interaction layer in TensorField Networks.
    """
    def __init__(self, f_in, f_out, learnable=True):
        """SE(3)-equivariant 1x1 convolution.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
            learnable: passed as requires_grad of every weight matrix
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out

        # Linear mappings: 1 per output feature type, shape (m_out, m_in)
        self.transform = nn.ParameterDict()
        for m_out, d_out in self.f_out.structure:
            m_in = self.f_in.structure_dict[d_out]
            weight = torch.randn(m_out, m_in) / np.sqrt(m_in)
            self.transform[str(d_out)] = nn.Parameter(weight, requires_grad=learnable)

    def __repr__(self):
        return f"G1x1SE3(structure={self.f_out})"

    def forward(self, features, **kwargs):
        """Apply the per-degree matrices; degrees without a matrix are dropped."""
        return {
            degree: torch.matmul(self.transform[str(degree)], value)
            for degree, value in features.items()
            if str(degree) in self.transform.keys()
        }
class GNormBias(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity with only learned biases."""

    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
                 num_layers: int = 0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc (stored only)
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

        # Regularization for computing phase: gradients explode otherwise
        self.eps = 1e-12

        # One bias row of shape (1, m) per feature type
        self.bias = nn.ParameterDict()
        for m, d in self.fiber.structure:
            self.bias[str(d)] = nn.Parameter(torch.randn(m).view(1, m))

    def __repr__(self):
        # The label intentionally stays as it was, although it differs from the class name.
        return "GNormTFN()"

    def forward(self, features, **kwargs):
        """Rescale every feature by nonlin(norm + bias) while keeping its direction."""
        output = {}
        for degree, value in features.items():
            # value shape: [..., m, 2*k+1]; lengths shape: [..., m, 1]
            lengths = value.norm(2, -1, keepdim=True).clamp_min(self.eps)
            phase = value / lengths

            # Nonlinearity on the biased norm
            gate = self.nonlin(lengths[..., 0] + self.bias[str(degree)])

            output[degree] = (gate.unsqueeze(-1) * phase).view(*value.shape)
        return output
class GAttentiveSelfInt(nn.Module):
    """Attention-weighted SE(3)-equivariant channel mixing, one network per degree."""

    def __init__(self, f_in, f_out):
        """SE(3)-equivariant 1x1 convolution.

        Args:
            f_in: input Fiber() of feature multiplicities and types
            f_out: output Fiber() of feature multiplicities and types
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.nonlin = nn.LeakyReLU()
        self.num_layers = 2
        self.eps = 1e-12  # regularisation for phase: gradients explode otherwise

        # one network for attention weights per degree
        self.transform = nn.ModuleDict()
        for degree, m_in in self.f_in.structure_dict.items():
            m_out = self.f_out.structure_dict[degree]
            self.transform[str(degree)] = self._build_net(m_in, m_out)

    def __repr__(self):
        return f"AttentiveSelfInteractionSE3(in={self.f_in}, out={self.f_out})"

    def _build_net(self, m_in: int, m_out):
        """Build the MLP mapping m_in*m_in channel products to m_in*m_out logits."""
        hidden = m_in * m_out
        width = m_in * m_in
        final = self.num_layers - 1

        layers = []
        for depth in range(1, self.num_layers):
            layers.append(nn.LayerNorm(int(width)))
            layers.append(self.nonlin)
            # TODO: implement cleaner init
            linear = nn.Linear(width, hidden, bias=(depth == final))
            layers.append(linear)
            nn.init.kaiming_uniform_(linear.weight)
            width = hidden
        return nn.Sequential(*layers)

    def forward(self, features, **kwargs):
        """Mix the channels of every degree with softmax attention weights."""
        output = {}
        for k, v in features.items():
            # v shape: [..., m, 2*k+1]
            degree = int(k)
            batch_dims = v.shape[:-2]
            m_in = self.f_in.structure_dict[degree]
            m_out = self.f_out.structure_dict[degree]
            assert v.shape[-2] == m_in
            assert v.shape[-1] == 2 * degree + 1

            # Inner products between all channel pairs: [..., m_in, m_in]
            gram = torch.einsum('...ac,...bc->...ab', [v, v])
            flat = gram.view(*batch_dims, m_in*m_in)  # [..., m_in*m_in]
            # Keep the sign but push the magnitude up to at least eps
            sign = flat.sign()
            flat = flat.abs_().clamp_min(self.eps)
            flat *= sign

            # perform attention
            logits = self.transform[str(k)](flat)  # [..., m_out*m_in]
            logits = logits.view(*batch_dims, m_out, m_in)  # [..., m_out, m_in]
            weights = F.softmax(input=logits, dim=-1)

            # shape [..., m_out, 2*k+1]
            output[k] = torch.einsum('...nm,...md->...nd', [weights, v])
        return output
class GNormSE3(nn.Module):
    """Graph Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
    """
    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True), num_layers: int=0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

        # Regularization for computing phase: gradients explode otherwise
        self.eps = 1e-12

        # Norm mappings: 1 per feature type
        self.transform = nn.ModuleDict()
        for m, d in self.fiber.structure:
            self.transform[str(d)] = self._build_net(int(m))

    def __repr__(self):
        return f"GNormSE3(num_layers={self.num_layers}, nonlin={self.nonlin})"

    def _build_net(self, m: int):
        """Build fnc for m channels: (norm, nonlin, linear) per layer, or just
        (norm, nonlin) when there are no linear layers."""
        if self.num_layers == 0:
            return nn.Sequential(BN(int(m)), self.nonlin)

        final = self.num_layers - 1
        layers = []
        for depth in range(self.num_layers):
            layers.append(BN(int(m)))
            layers.append(self.nonlin)
            # TODO: implement cleaner init
            linear = nn.Linear(m, m, bias=(depth == final))
            layers.append(linear)
            nn.init.kaiming_uniform_(linear.weight)
        return nn.Sequential(*layers)

    def forward(self, features, **kwargs):
        """Replace the norm of every feature by fnc(norm), keeping its direction."""
        output = {}
        for degree, value in features.items():
            # value shape: [..., m, 2*k+1]; lengths shape: [..., m, 1]
            lengths = value.norm(2, -1, keepdim=True).clamp_min(self.eps)
            phase = value / lengths

            # Transform on norms
            gate = self.transform[str(degree)](lengths[..., 0]).unsqueeze(-1)

            # Nonlinearity on norm
            output[degree] = (gate * phase).view(*value.shape)
        return output
class BN(nn.Module):
    """SE(3)-equivariant batch/layer normalization.

    Thin wrapper that applies ``nn.LayerNorm`` over the last dimension.
    """

    def __init__(self, m):
        """Create the normalisation layer.

        Args:
            m: int for number of output channels (size of the normalised axis)
        """
        super().__init__()
        self.bn = nn.LayerNorm(m)

    def forward(self, x):
        """Return the layer-normalised input."""
        normalised = self.bn(x)
        return normalised
class GConvSE3Partial(nn.Module):
    """Graph SE(3)-equivariant node -> edge layer"""
    def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
        """SE(3)-equivariant partial convolution.

        A partial convolution computes the inner product between a kernel and
        each input channel, without summing over the result from each input
        channel. This unfolded structure makes it amenable to be used for
        computing the value-embeddings of the attention mechanism.

        Args:
            f_in: list of tuples [(multiplicities, type),...]
            f_out: list of tuples [(multiplicities, type),...]
            edge_dim: number of dimensions for edge embedding
            x_ij: None, 'cat' or 'add'; how the relative position enters the
                type-1 input features (see the comments below)
        """
        super().__init__()
        self.f_out = f_out
        self.edge_dim = edge_dim
        # adding/concatenating relative position to feature vectors
        # 'cat' concatenates relative position & existing feature vector
        # 'add' adds it, but only if multiplicity > 1
        assert x_ij in [None, 'cat', 'add']
        self.x_ij = x_ij
        if x_ij == 'cat':
            # one extra type-1 channel is reserved for the relative position
            self.f_in = Fiber.combine(f_in, Fiber(structure=[(1,1)]))
        else:
            self.f_in = f_in
        # Node -> edge weights: one PairwiseConv per (input, output) degree pair
        self.kernel_unary = nn.ModuleDict()
        for (mi, di) in self.f_in.structure:
            for (mo, do) in self.f_out.structure:
                self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)

    def __repr__(self):
        return f'GConvSE3Partial(structure={self.f_out})'

    def udf_u_mul_e(self, d_out):
        """Compute the partial convolution for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            node -> edge function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            msg = 0
            for m_in, d_in in self.f_in.structure:
                # if type 1 and flag set, add relative position as feature
                if self.x_ij == 'cat' and d_in == 1:
                    # relative positions
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    # m_in counts the extra channel added in __init__
                    m_ori = m_in - 1
                    if m_ori == 0:
                        # no type 1 input feature, just use relative position
                        src = rel
                    else:
                        # features of src node, shape [edges, m_in*(2l+1), 1]
                        src = edges.src[f'{d_in}'].view(-1, m_ori*(2*d_in+1), 1)
                        # add to feature vector
                        src = torch.cat([src, rel], dim=1)
                elif self.x_ij == 'add' and d_in == 1 and m_in > 1:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                    rel = (edges.dst['x'] - edges.src['x']).view(-1, 3, 1)
                    # NOTE(review): slice assignment on the result of .view();
                    # if that view shares storage with edges.src this also
                    # changes the stored source features - confirm intended.
                    src[..., :3, :1] = src[..., :3, :1] + rel
                else:
                    src = edges.src[f'{d_in}'].view(-1, m_in*(2*d_in+1), 1)
                edge = edges.data[f'({d_in},{d_out})']
                msg = msg + torch.matmul(edge, src)
            msg = msg.view(msg.shape[0], -1, 2*d_out+1)
            return {f'out{d_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
        return fnc

    def forward(self, h, G=None, r=None, basis=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            h: dict of node-features
            G: minibatch of (homo)graphs
            r: inter-atomic distances
            basis: pre-computed Q * Y
        Returns:
            dict mapping each output degree (as a string) to the edge features
            stored under 'out<degree>'
        """
        with G.local_scope():
            # Add node features to local graph scope
            for k, v in h.items():
                G.ndata[k] = v
            # Add edge features
            if 'w' in G.edata.keys():
                w = G.edata['w'] # shape: [#edges_in_batch, #bond_types]
                feat = torch.cat([w, r], -1)
            else:
                feat = torch.cat([r, ], -1)
            # One kernel per (input degree, output degree) pair, stored on the edges
            for (mi, di) in self.f_in.structure:
                for (mo, do) in self.f_out.structure:
                    etype = f'({di},{do})'
                    G.edata[etype] = self.kernel_unary[etype](feat, basis)
            # Perform message-passing for each output feature type
            for d in self.f_out.degrees:
                G.apply_edges(self.udf_u_mul_e(d))
            return {f'{d}': G.edata[f'out{d}'] for d in self.f_out.degrees}
class GMABSE3(nn.Module):
    """Multi-headed SE(3)-equivariant attention block operating on a DGL graph."""

    def __init__(self, f_value: Fiber, f_key: Fiber, n_heads: int):
        """Store the fibers and head count used by the attention.

        Args:
            f_value: Fiber() object for value-embeddings
            f_key: Fiber() object for key-embeddings
            n_heads: number of heads
        """
        super().__init__()
        self.f_value = f_value
        self.f_key = f_key
        self.n_heads = n_heads
        # DGL releases after 0.4.4 return the edge dot-product with an extra
        # dimension (see forward)
        self.new_dgl = version.parse(dgl.__version__) > version.parse('0.4.4')

    def __repr__(self):
        return f'GMABSE3(n_heads={self.n_heads}, structure={self.f_value})'

    def udf_u_mul_e(self, d_out):
        """Build the DGL message function for one output feature type.

        Args:
            d_out: output feature type
        Returns:
            function mapping an edge batch to ``{'m': attention * value}``
        """
        def fnc(edges):
            # Broadcast the per-head attention weights over the two trailing
            # axes of the value tensor
            weights = edges.data['a'][..., None, None]
            return {'m': weights * edges.data[f'v{d_out}']}
        return fnc

    def forward(self, v, k: Dict=None, q: Dict=None, G=None, **kwargs):
        """Attention-weighted aggregation of edge values onto nodes.

        Args:
            v: dict of value edge-features
            k: dict of key edge-features
            q: dict of query node-features
            G: minibatch of (homo)graphs
        Returns:
            dict mapping ``'<degree>'`` to node features viewed as
            ``(-1, channels, 2*degree+1)``
        """
        heads = self.n_heads
        with G.local_scope():
            # Values are stored per head: [edges, heads, channels/heads, 2d+1]
            for channels, degree in self.f_value.structure:
                G.edata[f'v{degree}'] = v[f'{degree}'].view(-1, heads, channels//heads, 2*degree+1)
            # Keys (edges) and queries (nodes) in stacked per-head form
            G.edata['k'] = fiber2head(k, heads, self.f_key, squeeze=True)
            G.ndata['q'] = fiber2head(q, heads, self.f_key, squeeze=True)

            # Attention logits: inner product of edge key and destination query
            G.apply_edges(fn.e_dot_v('k', 'q', 'e'))
            logits = G.edata.pop('e')
            if self.new_dgl:
                # newer DGL adds a dimension; flatten back to [edges, heads]
                logits = logits.view([G.edata['k'].shape[0], heads])
            logits = logits / np.sqrt(self.f_key.n_features)
            G.edata['a'] = edge_softmax(G, logits)

            # Sum the attention-weighted values into each destination node
            for degree in self.f_value.degrees:
                G.update_all(self.udf_u_mul_e(degree), fn.sum('m', f'out{degree}'))

            return {
                f'{degree}': G.ndata[f'out{degree}'].view(-1, channels, 2*degree+1)
                for channels, degree in self.f_value.structure
            }
class GSE3Res(nn.Module):
    """Graph attention block with SE(3)-equivariance and skip connection"""
    def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4,
                 n_heads: int=1, learnable_skip=True, skip='cat', selfint='1x1', x_ij=None):
        """Build the value/key/query projections, the attention and the skip path.

        Args:
            f_in: Fiber() object for the input features
            f_out: Fiber() object for the output features
            edge_dim: forwarded to the value and key GConvSE3Partial layers
            div: channel divisor applied to f_out to obtain the value fiber
            n_heads: number of attention heads
            learnable_skip: forwarded as ``learnable`` to the G1x1SE3 projection
            skip: 'cat', 'sum' or None (see forward)
            selfint: 'att' or '1x1'; only consulted when skip == 'cat'
            x_ij: forwarded to the value and key GConvSE3Partial layers
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.div = div
        self.n_heads = n_heads
        self.skip = skip # valid: 'cat', 'sum', None
        # f_mid_out has same structure as 'f_out' but #channels divided by 'div'
        # this will be used for the values
        f_mid_out = {k: int(v // div) for k, v in self.f_out.structure_dict.items()}
        self.f_mid_out = Fiber(dictionary=f_mid_out)
        # f_mid_in has same structure as f_mid_out, but only degrees which are in f_in
        # this will be used for keys and queries
        # (queries are merely projected, hence degrees have to match input)
        f_mid_in = {d: m for d, m in f_mid_out.items() if d in self.f_in.degrees}
        self.f_mid_in = Fiber(dictionary=f_mid_in)
        self.edge_dim = edge_dim
        self.GMAB = nn.ModuleDict()
        # Projections
        self.GMAB['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim, x_ij=x_ij)
        self.GMAB['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim, x_ij=x_ij)
        self.GMAB['q'] = G1x1SE3(f_in, self.f_mid_in)
        # Attention
        self.GMAB['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)
        # Skip connections
        if self.skip == 'cat':
            self.cat = GCat(self.f_mid_out, f_in)
            # NOTE(review): if selfint is neither 'att' nor '1x1', self.project
            # is never created and forward() fails on the 'cat' path.
            if selfint == 'att':
                self.project = GAttentiveSelfInt(self.cat.f_out, f_out)
            elif selfint == '1x1':
                self.project = G1x1SE3(self.cat.f_out, f_out, learnable=learnable_skip)
        elif self.skip == 'sum':
            self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
            self.add = GSum(f_out, f_in)
            # the following checks whether the skip connection would change
            # the output fibre structure; the reason can be that the input has
            # more channels than the output (for at least one degree); this would
            # then cause a (hard to debug) error in the next layer
            assert self.add.f_out.structure_dict == f_out.structure_dict, \
                'skip connection would change output structure'
    def forward(self, features, G, **kwargs):
        """Run attention over G and merge the result with the input features.

        Args:
            features: dict of input node features
            G: graph handed to every sub-module
            **kwargs: forwarded to the value and key projections only
        Returns:
            with skip == 'cat': project(cat(attention output, features));
            with skip == 'sum': project(attention output) + features;
            otherwise the raw attention output (no projection is applied)
        """
        # Embeddings
        v = self.GMAB['v'](features, G=G, **kwargs)
        k = self.GMAB['k'](features, G=G, **kwargs)
        q = self.GMAB['q'](features, G=G)
        # Attention
        z = self.GMAB['attn'](v, k=k, q=q, G=G)
        if self.skip == 'cat':
            z = self.cat(z, features)
            z = self.project(z)
        elif self.skip == 'sum':
            # Skip + residual
            z = self.project(z)
            z = self.add(z, features)
        return z
### Helper and wrapper functions
class GSum(nn.Module):
    """SE(3)-equivariant graph residual sum function."""

    def __init__(self, f_x: Fiber, f_y: Fiber):
        """SE(3)-equivariant graph residual sum function.

        Args:
            f_x: Fiber() object for fiber of summands
            f_y: Fiber() object for fiber of summands
        """
        super().__init__()
        self.f_x = f_x
        self.f_y = f_y
        self.f_out = Fiber.combine_max(f_x, f_y)

    def __repr__(self):
        return f"GSum(structure={self.f_out})"

    @staticmethod
    def _pad_channels(t, n_channels):
        """Return `t` zero-padded along dim 1 up to `n_channels` (t itself if already wide enough)."""
        missing = n_channels - t.shape[1]
        if missing <= 0:
            return t
        # new_zeros inherits dtype and device from t, so the concatenation
        # never mixes dtypes or devices
        zeros = t.new_zeros(t.shape[0], missing, t.shape[2])
        return torch.cat([t, zeros], 1)

    def forward(self, x, y):
        """Sum two feature dicts degree by degree.

        For a degree present in both inputs, the tensor with fewer channels
        (dim 1) is zero-padded before the addition. A degree present in only
        one input is passed through unchanged. Neither input dict is modified
        (padding is applied to local copies only).

        Args:
            x: dict mapping ``'<degree>'`` to a 3-D tensor
            y: dict mapping ``'<degree>'`` to a 3-D tensor
        Returns:
            dict with one entry per degree of ``self.f_out`` found in x or y
        """
        out = {}
        for k in self.f_out.degrees:
            k = str(k)
            if (k in x) and (k in y):
                n_channels = max(x[k].shape[1], y[k].shape[1])
                out[k] = self._pad_channels(x[k], n_channels) + self._pad_channels(y[k], n_channels)
            elif k in x:
                out[k] = x[k]
            elif k in y:
                out[k] = y[k]
        return out
class GCat(nn.Module):
    """Channel-wise concatenation restricted to the degrees of f_x."""

    def __init__(self, f_x: Fiber, f_y: Fiber):
        """Derive the output fiber: channels of f_x plus those of f_y for shared degrees.

        Args:
            f_x: Fiber() whose degrees define the output degrees
            f_y: Fiber() contributing channels only for degrees also in f_x
        """
        super().__init__()
        self.f_x = f_x
        self.f_y = f_y
        shared = f_y.degrees
        channels = {
            deg: f_x.dict[deg] + (f_y.dict[deg] if deg in shared else 0)
            for deg in f_x.degrees
        }
        self.f_out = Fiber(dictionary=channels)

    def __repr__(self):
        return f"GCat(structure={self.f_out})"

    def forward(self, x, y):
        """Concatenate x[k] and y[k] along dim 1 for each output degree k.

        Degrees missing from y are copied from x unchanged.
        """
        keys = [str(deg) for deg in self.f_out.degrees]
        return {
            key: torch.cat([x[key], y[key]], 1) if key in y else x[key]
            for key in keys
        }
class GAvgPooling(nn.Module):
    """Graph Average Pooling module."""

    def __init__(self, type='0'):
        """
        Args:
            type: '0' to pool the degree-0 features, '1' to pool the degree-1
                  features component by component
        """
        super().__init__()
        self.pool = AvgPooling()
        self.type = type

    def forward(self, features, G, **kwargs):
        """Average-pool node features over the graph(s) in G.

        Args:
            features: dict of node features keyed by degree string
            G: graph handed to the pooling operator
        Returns:
            type '0': the pooled tensor of ``features['0'][..., -1]``
            type '1': ``{'1': tensor}`` whose last axis holds the three pooled
                      components of ``features['1']``
        Raises:
            NotImplementedError: for any other pooling type
        """
        if self.type == '0':
            h = features['0'][...,-1]
            pooled = self.pool(G, h)
        elif self.type == '1':
            # pool each of the three vector components separately
            pooled = [self.pool(G, features['1'][..., i]).unsqueeze(-1) for i in range(3)]
            pooled = {'1': torch.cat(pooled, dim=-1)}
        else:
            # Raise instead of print + exit(): library code must not terminate
            # the interpreter, and exit() is only a site-module convenience.
            raise NotImplementedError(
                f'GAvgPooling for type {self.type!r} not implemented (supported: \'0\', \'1\')')
        return pooled
class GMaxPooling(nn.Module):
    """Max-pooling of the degree-0 node features over a graph."""

    def __init__(self):
        super().__init__()
        self.pool = MaxPooling()

    def forward(self, features, G, **kwargs):
        """Pool ``features['0']`` (last index of its trailing axis) over G."""
        scalars = features['0']
        return self.pool(G, scalars[..., -1])

View file

@ -0,0 +1,91 @@
#!/usr/bin/env python
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
'''
Created on Apr 30, 2014
@author: meiermark
'''
import sys
import mmap
from collections import namedtuple
# One record of an .ffindex file: entry name, byte offset into the data file,
# and entry length in bytes (including the trailing NUL written by write_entry).
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")


def read_index(ffindex_filename):
    """Read an ffindex index file into a list of FFindexEntry records.

    Each non-blank line must hold three tab-separated fields:
    name, offset, length.

    Args:
        ffindex_filename: path of the index file
    Returns:
        list of FFindexEntry in file order
    """
    entries = []
    # The context manager closes the file even when a malformed line raises
    with open(ffindex_filename) as fh:
        for line in fh:
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            tokens = line.split("\t")
            entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
    return entries
def read_data(ffdata_filename):
    """Memory-map an ffindex data file read-only.

    Args:
        ffdata_filename: path of the data file
    Returns:
        mmap.mmap object covering the whole file; it remains valid after the
        file object is closed
    """
    # `with` releases the descriptor even when mmap() raises (e.g. on an
    # empty file); access=ACCESS_READ is the portable spelling of a
    # read-only mapping (prot= exists on Unix only).
    with open(ffdata_filename, "rb") as fh:
        return mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
def get_entry_by_name(name, index):
    """Return the first entry of `index` whose name equals `name`, else None."""
    #TODO: bsearch
    return next((candidate for candidate in index if name == candidate.name), None)
def read_entry_lines(entry, data):
    """Decode one entry as UTF-8 text and split it into lines.

    The final byte of the entry is excluded from the slice.
    """
    first = entry.offset
    last = first + entry.length - 1
    text = data[first:last].decode("utf-8")
    return text.split("\n")
def read_entry_data(entry, data):
    """Return the raw bytes of one entry, excluding its final byte."""
    first = entry.offset
    last = first + entry.length - 1
    return data[first:last]
def write_entry(entries, data_fh, entry_name, offset, data):
    """Append one entry to an open data file and record it in `entries`.

    All bytes of `data` except the last are written, followed by one zero
    byte, so exactly len(data) bytes are emitted.

    Returns:
        the offset at which the next entry starts
    """
    size = len(data)
    payload = data[:-1]
    data_fh.write(payload)
    data_fh.write(bytearray(1))
    entries.append(FFindexEntry(entry_name, offset, size))
    return offset + size
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
    """Read `file_name` as bytes and store its content via write_entry.

    Returns:
        the offset at which the next entry starts
    """
    with open(file_name, "rb") as source:
        content = bytearray(source.read())
    next_offset = write_entry(entries, data_fh, entry_name, offset, content)
    return next_offset
def finish_db(entries, ffindex_filename, data_fh):
    """Close the data file, then write the collected entries to the index file."""
    data_fh.close()
    write_entries_to_db(entries=entries, ffindex_filename=ffindex_filename)
def write_entries_to_db(entries, ffindex_filename):
    """Write the index file, one "name<TAB>offset<TAB>length" line per entry.

    Entries are written sorted by name. The original code called sorted()
    and discarded the result, so the index was written in insertion order.
    The caller's list is left untouched.

    Args:
        entries: iterable of records with name, offset and length attributes
        ffindex_filename: path of the index file to (over)write
    """
    ordered = sorted(entries, key=lambda x: x.name)
    with open(ffindex_filename, "w") as index_fh:
        for entry in ordered:
            # {name:.64} truncates names to at most 64 characters
            index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
def write_entry_to_file(entry, data, file):
    """Write the text lines of one entry to the file at path `file`.

    Every line returned by read_entry_lines is written followed by a newline.

    Args:
        entry: index record (offset/length) of the entry to extract
        data: data buffer holding all entries
        file: destination path
    """
    # The original called the undefined name `read_lines`; the helper that
    # exists in this module is read_entry_lines.
    lines = read_entry_lines(entry, data)
    with open(file, "w") as fh:
        for line in lines:
            fh.write(line+"\n")

View file

@ -0,0 +1,216 @@
import numpy as np
import torch
# Default binning parameters for the distance/orientation maps below.
PARAMS = {
    "DMIN" : 2.0,   # lower edge of the distance bins (presumably Angstrom; confirm)
    "DMAX" : 20.0,  # distances at or above this are treated as "no contact"
    "DBINS" : 36,   # number of distance bins between DMIN and DMAX
    "ABINS" : 36,   # number of angle bins covering 2*pi
}
# ============================================================
def get_pair_dist(a, b):
    """Euclidean distances between every point of `a` and every point of `b`.

    Parameters
    ----------
    a,b : pytorch tensors of shape [batch,nres,3]
          store Cartesian coordinates of two sets of atoms
    Returns
    -------
    dist : pytorch tensor of shape [batch,nres,nres]
           stores pairwise distances between atoms in a and b
    """
    return torch.cdist(a, b, p=2)
# ============================================================
def get_ang(a, b, c):
    """calculate planar angles for all consecutive triples (a[i],b[i],c[i])
    from Cartesian coordinates of three sets of atoms a,b,c

    Parameters
    ----------
    a,b,c : pytorch tensors of shape [batch,nres,3]
            store Cartesian coordinates of three sets of atoms
    Returns
    -------
    ang : pytorch tensor of shape [batch,nres]
          stores resulting planar angles (radians, vertex at b)
    """
    v = a - b
    w = c - b
    v = v / torch.norm(v, dim=-1, keepdim=True)
    w = w / torch.norm(w, dim=-1, keepdim=True)
    vw = torch.sum(v*w, dim=-1)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where acos returns NaN; clamp first. NaN inputs still yield NaN.
    return torch.acos(torch.clamp(vw, -1.0, 1.0))
# ============================================================
def get_dih(a, b, c, d):
    """Dihedral angle of every quadruple (a[i],b[i],c[i],d[i]).

    Parameters
    ----------
    a,b,c,d : pytorch tensors of shape [batch,nres,3]
              store Cartesian coordinates of four sets of atoms
    Returns
    -------
    dih : pytorch tensor of shape [batch,nres]
          stores resulting dihedrals (radians, from atan2)
    """
    b0 = a - b
    axis = c - b
    b2 = d - c
    # unit vector along the central bond
    axis = axis / torch.norm(axis, dim=-1, keepdim=True)

    def reject(vec):
        # component of vec perpendicular to the central bond
        return vec - torch.sum(vec*axis, dim=-1, keepdim=True)*axis

    v = reject(b0)
    w = reject(b2)
    cos_part = torch.sum(v*w, dim=-1)
    sin_part = torch.sum(torch.cross(axis,v,dim=-1)*w, dim=-1)
    return torch.atan2(sin_part, cos_part)
# ============================================================
def xyz_to_c6d(xyz, params=PARAMS):
    """convert cartesian coordinates into 2d distance
    and orientation maps

    Parameters
    ----------
    xyz : pytorch tensor of shape [batch,nres,3,3]
          stores Cartesian coordinates of backbone N,Ca,C atoms
    params : dict providing 'DMAX', the contact distance cutoff
    Returns
    -------
    c6d : pytorch tensor of shape [batch,nres,nres,4]
          stores stacked dist,omega,theta,phi 2D maps
    mask : pytorch tensor of shape [batch,nres,nres]
           1.0 for residue pairs with distance below DMAX, 0.0 elsewhere
    """
    batch = xyz.shape[0]
    nres = xyz.shape[1]
    # three anchor atoms
    N = xyz[:,:,0]
    Ca = xyz[:,:,1]
    C = xyz[:,:,2]
    # recreate Cb given N,Ca,C
    b = Ca - N
    c = C - Ca
    a = torch.cross(b, c, dim=-1)
    Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
    # 6d coordinates order: (dist,omega,theta,phi)
    c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
    dist = get_pair_dist(Cb,Cb)
    # NaN distances (from NaN coordinates) are treated as out of range
    dist[torch.isnan(dist)] = 999.9
    # adding 999.9 on the diagonal excludes self-pairs from the contact set
    c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
    # indices (batch, i, j) of all pairs within DMAX; note that `b` is reused
    # here and no longer holds the Ca-N vector
    b,i,j = torch.where(c6d[...,0]<params['DMAX'])
    # channel 1: dihedral Ca(i)-Cb(i)-Cb(j)-Ca(j)
    c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
    # channel 2: dihedral N(i)-Ca(i)-Cb(i)-Cb(j)
    c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
    # channel 3: planar angle Ca(i)-Cb(i)-Cb(j)
    c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
    # fix long-range distances
    c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
    mask = torch.zeros((batch, nres,nres), dtype=xyz.dtype, device=xyz.device)
    mask[b,i,j] = 1.0
    return c6d, mask
def xyz_to_t2d(xyz_t, t0d, params=PARAMS):
    """Build 2D template features from template backbone coordinates.

    Parameters
    ----------
    xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
            stores Cartesian coordinates of template backbone N,Ca,C atoms
    t0d : 0-D template features (HHprob, seqID, similarity) [batch, templ, 3]
    Returns
    -------
    t2d : pytorch tensor of shape [batch,templ,nres,nres,1+6+t0d.shape[-1]]
          scaled distance, sin/cos of the three orientation angles and the
          0-D features tiled over all residue pairs; NaNs replaced by 0
    """
    B, T, L = xyz_t.shape[:3]
    flat_c6d, flat_mask = xyz_to_c6d(xyz_t.view(B*T,L,3,3), params=params)
    c6d = flat_c6d.view(B, T, L, L, 4)
    mask = flat_mask.view(B, T, L, L, 1)

    # masked distance scaled by DMAX and limited to [0, 1]  -> (B, T, L, L, 1)
    dist = torch.clamp(c6d[...,:1]*mask / params['DMAX'], 0.0, 1.0)

    # sin and cos of omega, theta, phi, zeroed outside contacts -> (B, T, L, L, 6)
    angles = c6d[...,1:]
    orien = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)*mask

    # broadcast the per-template features to every residue pair
    tiled_t0d = t0d.unsqueeze(2).unsqueeze(3).expand(-1, -1, L, L, -1)

    t2d = torch.cat((dist, orien, tiled_t0d), dim=-1)
    t2d[torch.isnan(t2d)] = 0.0
    return t2d
# ============================================================
def c6d_to_bins(c6d,params=PARAMS):
    """Discretise distance and orientation maps with torch.bucketize.

    The last axis of `c6d` holds (dist, omega, theta, phi). Pairs whose
    distance lands in the last distance bin (index DBINS) get the dedicated
    "no contact" index in the three angular channels as well.
    Returns a uint8 tensor with the four bin indices stacked on the last axis.
    """
    n_dist = params['DBINS']
    n_ang = params['ABINS']
    dstep = (params['DMAX'] - params['DMIN']) / n_dist
    astep = 2.0*np.pi / n_ang
    like = dict(dtype=c6d.dtype, device=c6d.device)

    # upper bin edges for each channel
    dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], n_dist, **like)
    ab360 = torch.linspace(-np.pi+astep, np.pi, n_ang, **like)
    ab180 = torch.linspace(astep, np.pi, n_ang//2, **like)

    db, ob, tb, pb = [
        torch.bucketize(c6d[...,channel].contiguous(), edges)
        for channel, edges in enumerate((dbins, ab360, ab360, ab180))
    ]

    # synchronize no-contact bins
    no_contact = db == n_dist
    ob[no_contact] = n_ang
    tb[no_contact] = n_ang
    pb[no_contact] = n_ang//2
    return torch.stack([db,ob,tb,pb],axis=-1).to(torch.uint8)
# ============================================================
def dist_to_bins(dist,params=PARAMS):
    """Convert distances to integer bin indices in [0, DBINS].

    Everything below DMIN falls into bin 0 and everything beyond the last
    bin is assigned index DBINS.
    """
    n_bins = params['DBINS']
    dstep = (params['DMAX'] - params['DMIN']) / n_bins
    raw = torch.round((dist-params['DMIN']-dstep/2)/dstep)
    return torch.clamp(raw, min=0, max=n_bins).long()
# ============================================================
def c6d_to_bins2(c6d,params=PARAMS):
    """Discretise distance and orientation maps by rounding.

    The last axis of `c6d` holds (dist, omega, theta, phi). Returns a long
    tensor with the four bin indices stacked on the last axis.
    """
    n_dist = params['DBINS']
    n_ang = params['ABINS']
    dstep = (params['DMAX'] - params['DMIN']) / n_dist
    astep = 2.0*np.pi / n_ang

    dist, omega, theta, phi = (c6d[...,channel] for channel in range(4))
    db = torch.round((dist-params['DMIN']-dstep/2)/dstep)
    ob = torch.round((omega+np.pi-astep/2)/astep)
    tb = torch.round((theta+np.pi-astep/2)/astep)
    pb = torch.round((phi-astep/2)/astep)

    # put all d<dmin into one bin and cap the distance index at DBINS
    db = torch.clamp(db, min=0, max=n_dist)

    # synchronize no-contact bins
    no_contact = db == n_dist
    ob[no_contact] = n_ang
    tb[no_contact] = n_ang
    pb[no_contact] = n_ang//2
    return torch.stack([db,ob,tb,pb],axis=-1).long()

View file

@ -0,0 +1,255 @@
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
import random
import util
import torch
from ffindex import *
# Three-letter residue name -> one-letter amino-acid code
# (the 20 residue types listed here only).
to1letter = {
    "ALA":'A', "ARG":'R', "ASN":'N', "ASP":'D', "CYS":'C',
    "GLN":'Q', "GLU":'E', "GLY":'G', "HIS":'H', "ILE":'I',
    "LEU":'L', "LYS":'K', "MET":'M', "PHE":'F', "PRO":'P',
    "SER":'S', "THR":'T', "TRP":'W', "TYR":'Y', "VAL":'V' }
# read A3M and convert letters into
# integers in the 0..20 range,
def parse_a3m(filename):
    """Read an A3M alignment and encode it as integers in the 0..20 range.

    Header lines (starting with '>') are skipped, lowercase letters
    (insertions) are removed, and every remaining character is mapped to its
    index in "ARNDCQEGHILKMFPSTWYV-"; characters outside that alphabet are
    encoded as 20 (gap).

    Args:
        filename: path of the A3M file
    Returns:
        numpy uint8 array of shape [n_sequences, alignment_length]
    """
    msa = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    # read file line by line; the context manager guarantees the handle is closed
    with open(filename,"r") as a3m:
        for line in a3m:
            # skip labels
            if line[0] == '>':
                continue
            # remove right whitespaces
            line = line.rstrip()
            # a blank line would otherwise add an empty row and make the MSA ragged
            if not line:
                continue
            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))
    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    # ASCII codes of the alphabet are all above 20, so already-mapped
    # values can never be remapped by a later iteration
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i
    # treat all unknown characters as gaps
    msa[msa > 20] = 20
    return msa
# parse HHsearch output
def parse_hhr(filename, ffindex, idmax=105.0):
    """Parse an HHsearch .hhr file into a list of usable template hits.

    Args:
        filename: path of the .hhr file
        ffindex: iterable of index entries; their `.name` values form the set
                 of templates available in the database
        idmax: hits whose reported sequence identity exceeds this are skipped
    Returns:
        list of [label, matches, p/100, seqid/100, sim/10], one item per kept
        hit, where `matches` is an integer array with one row
        [query_pos, template_pos, similarity_code-1, confidence_code-1] per
        aligned column whose similarity symbol is not a blank

    The parser relies on fixed character columns of the summary table and on
    fixed line offsets inside each alignment block.
    NOTE(review): those offsets assume a single alignment block per hit in the
    default HHsearch layout - confirm for long alignments.
    """
    # labels present in the database
    label_set = set([i.name for i in ffindex])
    out = []
    with open(filename, "r") as hhr:
        # read .hhr into a list of lines
        lines = [s.rstrip() for _,s in enumerate(hhr)]
        # read list of all hits: the summary table starts two lines after the
        # first empty line and ends at the next empty line
        start = lines.index("") + 2
        stop = lines[start:].index("") + start
        hits = []
        for line in lines[start:stop]:
            # ID of the hit
            #label = re.sub('_','',line[4:10].strip())
            label = line[4:10].strip()
            # position in the query where the alignment starts
            qstart = int(line[75:84].strip().split("-")[0])-1
            # position in the template where the alignment starts
            tstart = int(line[85:94].strip().split("-")[0])-1
            # last field: integer read from columns 69-75 (compared against 10
            # below as the hit length)
            hits.append([label, qstart, tstart, int(line[69:75])])
        # get line numbers where each hit starts
        start = [i for i,l in enumerate(lines) if l and l[0]==">"] # and l[1:].strip() in label_set]
        # process hits
        for idx,i in enumerate(start):
            # skip if hit is too short
            if hits[idx][3] < 10:
                continue
            # skip if template is not in the database
            if hits[idx][0] not in label_set:
                continue
            # get hit statistics: '=' and '%' are blanked out, then every
            # second token of the statistics line is taken as a value
            p,e,s,_,seqid,sim,_,neff = [float(s) for s in re.sub('[=%]', ' ', lines[i+1]).split()[1::2]]
            # skip too similar hits
            if seqid > idmax:
                continue
            query = np.array(list(lines[i+4].split()[3]), dtype='|S1')
            tmplt = np.array(list(lines[i+8].split()[3]), dtype='|S1')
            # similarity symbols -> their index in " =-.+|"
            simlr = np.array(list(lines[i+6][22:]), dtype='|S1').view(np.uint8)
            abc = np.array(list(" =-.+|"), dtype='|S1').view(np.uint8)
            for k in range(abc.shape[0]):
                simlr[simlr == abc[k]] = k
            # confidence symbols -> their index in " 0123456789"
            confd = np.array(list(lines[i+11][22:]), dtype='|S1').view(np.uint8)
            abc = np.array(list(" 0123456789"), dtype='|S1').view(np.uint8)
            for k in range(abc.shape[0]):
                confd[confd == abc[k]] = k
            # running residue counters (gaps do not advance the position)
            qj = np.cumsum(query!=b'-') + hits[idx][1]
            tj = np.cumsum(tmplt!=b'-') + hits[idx][2]
            # matched positions
            matches = np.array([[q-1,t-1,s-1,c-1] for q,t,s,c in zip(qj,tj,simlr,confd) if s>0])
            # skip short hits
            ncol = matches.shape[0]
            if ncol<10:
                continue
            # save hit
            #out.update({hits[idx][0] : [matches,p/100,seqid/100,neff/10]})
            out.append([hits[idx][0],matches,p/100,seqid/100,sim/10])
    return out
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename):
    """Read backbone N, CA and C coordinates from a PDB file.

    Only records starting with "ATOM" are considered; coordinates and residue
    numbers are read from the fixed PDB columns.

    Args:
        filename: path of the PDB file
    Returns:
        xyz: numpy array stacking the N, CA and C coordinate arrays along
             axis 0, i.e. shape [3, n_residues, 3] (np.stack raises if the
             three atom types do not occur equally often)
        idx: numpy array with the residue number of every CA atom
    """
    # Read once and close the handle deterministically
    with open(filename,'r') as pdb:
        atom_lines = [l for l in pdb if l[:4]=="ATOM"]

    def coords_of(atom_name):
        # xyz of every ATOM record whose atom-name field matches
        return np.array([[float(l[30:38]), float(l[38:46]), float(l[46:54])]
                         for l in atom_lines if l[12:16].strip()==atom_name])

    xyz = np.stack([coords_of("N"), coords_of("CA"), coords_of("C")], axis=0)
    # indices of residues observed in the structure
    idx = np.array([int(l[22:26]) for l in atom_lines if l[12:16].strip()=="CA"])
    return xyz,idx
def parse_pdb_lines(lines):
    """Extract per-residue atom coordinates from PDB text lines.

    Args:
        lines: iterable of PDB record strings
    Returns:
        xyz: float32 array [n_residues, 14, 3]; slots of atoms that were not
             found are filled with 0.0
        mask: bool array [n_residues, 14], True where an atom was found
        idx: array with the residue number of every CA atom (row order of xyz)

    ATOM records of a residue without a CA record are ignored; the original
    implementation raised ValueError from list.index in that case.
    """
    # indices of residues observed in the structure
    idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
    # residue number -> row; setdefault keeps the FIRST occurrence, matching
    # list.index, while avoiding a linear scan for every atom
    row_of = {}
    for row, res_no in enumerate(idx_s):
        row_of.setdefault(res_no, row)
    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), 14, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
        idx = row_of.get(resNo)
        if idx is None:
            # residue has no CA record, hence no row in xyz
            continue
        # slot of this atom within the residue-type specific atom list
        for i_atm, tgtatm in enumerate(util.aa2long[util.aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break
    # save atom mask
    mask = np.logical_not(np.isnan(xyz[...,0]))
    xyz[np.logical_not(mask)] = 0.0
    return xyz,mask,np.array(idx_s)
def parse_templates(ffdb, hhr_fn, atab_fn, n_templ=10):
    """Collect aligned template coordinates and scores for all usable hits.

    Args:
        ffdb: object with `.index` (ffindex entries) and `.data` (data buffer)
        hhr_fn: path of the HHsearch .hhr file
        atab_fn: path of the tabulated alignment (.atab) file
        n_templ: accepted for interface compatibility; not used here
    Returns:
        xyz:  float32 tensor, template atom coordinates of all matched positions
        qmap: int64 tensor [n_matched, 2] of (query position, template counter)
        f0d:  float32 tensor, per-hit statistics taken from the .hhr file
        f1d:  float32 tensor, three per-position scores from the .atab file
        ids:  list of template names, in counter order

    Hits are skipped when their template is missing from `ffdb` or when fewer
    than 10 aligned positions have coordinates.
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    with open(atab_fn, "r") as atab:
        atab_lines = atab.readlines()
    for l in atab_lines:
        if l[0]=='>':
            key = l[1:].split()[0]
            hits.append([key,[],[]])
        elif "score" in l or "dssp" in l:
            continue
        else:
            # pad with zeros so that short lines still provide three scores
            hi = l.split()[:5]+[0.0,0.0,0.0]
            hits[-1][1].append([int(hi[0]),int(hi[1])])
            hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(hhr_fn, "r") as hhr:
        lines = hhr.readlines()
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
    # parse templates from FFDB
    for hi in hits:
        #if hi[0] not in ffids:
        #    continue
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        # appends xyz, mask and residue numbers -> hit list grows to 7 items
        hi += list(parse_pdb_lines(data))
    # process hits
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
    for data in hits:
        # fewer than 7 items: template structure was not found above
        if len(data)<7:
            continue
        qi,ti = np.array(data[1]).T
        # aligned template positions that also have coordinates
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue
        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1
    xyz = np.vstack(xyz).astype(np.float32)
    # np.long was deprecated in NumPy 1.20 and removed in 1.24; int64 is the
    # explicit equivalent and what torch uses for index tensors
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)
    return torch.from_numpy(xyz), torch.from_numpy(qmap), \
           torch.from_numpy(f0d), torch.from_numpy(f1d), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Arrange the first `n_templ` parsed templates on the query frame.

    Args:
        qlen: query length
        ffdb, hhr_fn, atab_fn: forwarded to parse_templates
        n_templ: maximum number of templates to keep
    Returns:
        xyz: [npick, qlen, 3, 3] coordinates of the first three atoms per
             residue, NaN where the template has no aligned residue
        f1d: [npick, qlen, 3] per-position scores, 0 where unaligned
        f0d: [npick, 3] per-template statistics (columns 0 and 4 divided by
             100, column 5 as is)
    """
    xyz_t, qmap, t0d, t1d, ids = parse_templates(ffdb, hhr_fn, atab_fn)
    npick = min(n_templ, len(ids))

    xyz = torch.full((npick, qlen, 3, 3), np.nan).float()
    f1d = torch.zeros((npick, qlen, 3)).float()
    per_template = []

    for slot, templ in enumerate(torch.arange(npick)):
        # rows of the flat arrays that belong to this template
        rows = torch.where(qmap[:,1] == templ)[0]
        query_pos = qmap[rows, 0]
        xyz[slot, query_pos] = xyz_t[rows, :3]
        f1d[slot, query_pos] = t1d[rows, :3]
        stats = [t0d[templ,0]/100.0, t0d[templ, 4]/100.0, t0d[templ,5]]
        per_template.append(torch.stack(stats, dim=-1))
    return xyz, f1d, torch.stack(per_template, dim=0)

View file

@ -0,0 +1,284 @@
import math
import torch
import torch.nn.functional as F
from torch import nn
from functools import partial
# Original implementation from https://github.com/lucidrains/performer-pytorch
# helpers
def exists(val):
    """True for every value except None."""
    return not (val is None)
def empty(tensor):
    """True when the tensor holds no elements."""
    n_elements = tensor.numel()
    return n_elements == 0
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def get_module_device(module):
    """Device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device
def find_modules(nn_module, type):
    """List every module in `nn_module.modules()` that is an instance of `type`."""
    matches = []
    for candidate in nn_module.modules():
        if isinstance(candidate, type):
            matches.append(candidate)
    return matches
# kernel functions
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    """Random-feature map of the softmax kernel.

    Args:
        data: tensor whose first two dims are (b, h) and whose last dim is
              the feature dim d
        projection_matrix: tensor of shape (j, d)
        is_query: subtract the per-row maximum (True) or the global maximum
                  of the projected data (False) before exponentiation
        normalize_data: scale the data by d ** -0.25
        eps: constant added after exponentiation
        device: unused
    Returns:
        tensor with the last dim replaced by j, same dtype as `data`
    """
    b, h, *_ = data.shape
    scale = (data.shape[-1] ** -0.25) if normalize_data else 1.
    ratio = (projection_matrix.shape[0] ** -0.5)

    # tile the projection over batch and heads: (b, h, j, d)
    projection = projection_matrix[None, None].repeat(b, h, 1, 1).type_as(data)
    data_dash = torch.einsum('...id,...jd->...ij', (scale * data), projection)

    # half squared norm of the scaled data, kept as a trailing singleton dim
    sq_norm = torch.sum(data ** 2, dim=-1)
    diag_data = ((sq_norm / 2.0) * (scale ** 2)).unsqueeze(dim=-1)

    # stabiliser subtracted inside the exponential
    if is_query:
        shift = torch.max(data_dash, dim=-1, keepdim=True).values
    else:
        shift = torch.max(data_dash)
    data_dash = ratio * (torch.exp(data_dash - diag_data - shift) + eps)
    return data_dash.type_as(data)
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(inplace=True), kernel_epsilon = 0.001, normalize_data = True, device = None):
    """Feature map ``kernel_fn(scaled data [@ projection.T]) + kernel_epsilon``.

    Args:
        data: tensor with at least two leading dims and feature dim last
        projection_matrix: optional (j, d) matrix; None skips the projection
        kernel_fn: non-linearity applied to the (projected) data
        kernel_epsilon: constant added to the result
        normalize_data: scale the data by d ** -0.25
        device: unused
    """
    b, h, *_ = data.shape
    scale = (data.shape[-1] ** -0.25) if normalize_data else 1.
    features = scale * data
    if projection_matrix is not None:
        features = torch.matmul(features, projection_matrix.T)
    return kernel_fn(features) + kernel_epsilon
def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
    """Random orthogonal (cols x cols) matrix from the QR factorisation of a Gaussian matrix.

    The factorisation runs on the CPU; the factors are then moved to `device`.
    """
    gaussian = torch.randn((cols, cols), device = device)
    q_cpu, r_cpu = torch.linalg.qr(gaussian.cpu(), 'reduced')
    q = q_cpu.to(device)
    r = r_cpu.to(device)
    # proposed by @Parskatt
    # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
    if qr_uniform_q:
        signs = torch.diag(r, 0).sign()
        q *= signs
    return q.t()
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
    """(nb_rows x nb_columns) matrix built from stacked orthogonal blocks with rescaled rows.

    Args:
        scaling: 0 -> each row is multiplied by the norm of a freshly drawn
                 Gaussian row; 1 -> every row is multiplied by sqrt(nb_columns)
    Raises:
        ValueError: for any other `scaling`
    """
    def new_block():
        return orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)

    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = [new_block() for _ in range(nb_full_blocks)]

    # a final partial block supplies the rows that do not fill a whole block
    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        block_list.append(new_block()[:remaining_rows])
    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')
    return torch.diag(multiplier) @ final_matrix
# linear attention classes with softmax kernel
# non-causal linear attention
def linear_attention(q, k, v):
    """Non-causal linear attention on feature-mapped queries and keys.

    Computes ``D^-1 * q @ (k^T v / L)`` with
    ``D = q . mean(k over the sequence dim)`` and ``L = k.shape[-2]``.
    """
    seq_len = k.shape[-2]
    inv_norm = 1. / torch.einsum('...nd,...d->...n', q, k.mean(dim=-2))
    context = torch.einsum('...nd,...ne->...de', k/float(seq_len), v)
    scaled_q = torch.einsum('...n,...nd->...nd', inv_norm, q)
    return torch.einsum('...nd,...de->...ne', scaled_q, context)
class FastAttention(nn.Module):
    """Linear attention using random-feature (or softmax-normalised) queries and keys."""

    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, generalized_attention = False, kernel_fn = nn.ReLU(inplace=True), qr_uniform_q = False, no_projection = False):
        """
        Args:
            dim_heads: feature dimension of each head
            nb_features: number of random features; defaults to
                         int(dim_heads * log(dim_heads))
            ortho_scaling: `scaling` argument of the projection generator
            generalized_attention: use generalized_kernel instead of softmax_kernel
            kernel_fn: non-linearity used by the generalized kernel
            qr_uniform_q: forwarded to the projection generator
            no_projection: skip random features and softmax q and k directly
        """
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        if not no_projection:
            self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
            projection_matrix = self.create_projection()
            self.register_buffer('projection_matrix', projection_matrix)
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        """Replace the projection buffer in place with a freshly drawn matrix."""
        if self.no_projection:
            # no projection buffer (and no generator) exists in this mode
            return
        projections = self.create_projection(device = device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        """Apply linear attention to q, k, v (feature dim last, sequence dim second to last)."""
        device = q.device
        if self.no_projection:
            q = q.softmax(dim = -1)
            # The result must be assigned: softmax is not in-place, so the
            # original statement left the keys un-normalised.
            k = k.softmax(dim = -2)
        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))
        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
            q = create_kernel(q, is_query = True)
            k = create_kernel(k, is_query = False)
        attn_fn = linear_attention
        out = attn_fn(q, k, v)
        return out
# classes
class ReZero(nn.Module):
    """Wrap `fn` and scale its output by a learnable scalar initialised to 1e-3."""

    def __init__(self, fn):
        super().__init__()
        self.g = nn.Parameter(torch.tensor(1e-3))
        self.fn = fn

    def forward(self, x, **kwargs):
        inner = self.fn(x, **kwargs)
        return inner * self.g
class PreScaleNorm(nn.Module):
    """Rescale the input to a learnable norm along the last dim, then call `fn`."""

    def __init__(self, dim, fn, eps=1e-5):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x, **kwargs):
        # norms below eps are raised to eps to avoid dividing by zero
        norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        scaled = x / norm * self.g
        return self.fn(scaled, **kwargs)
class PreLayerNorm(nn.Module):
    """Apply LayerNorm over the last `dim` features, then call `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class Chunk(nn.Module):
    """Apply `fn` to `chunks` slices of the input along one dim and re-join the results."""

    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        # a single chunk needs no splitting
        if self.chunks == 1:
            return self.fn(x, **kwargs)
        outputs = []
        for piece in x.chunk(self.chunks, dim = self.dim):
            outputs.append(self.fn(piece, **kwargs))
        return torch.cat(outputs, dim = self.dim)
class SelfAttention(nn.Module):
    """Multi-head attention layer backed by FastAttention (linear attention)."""

    def __init__(self, dim, k_dim=None, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(inplace=True), qr_uniform_q = False, dropout = 0., no_projection = False):
        """
        Args:
            dim: query/output feature dimension; must be divisible by `heads`
            k_dim: key/value input dimension (defaults to `dim`)
            heads: number of attention heads
            local_heads, local_window_size: accepted but not used by this class
            nb_features, generalized_attention, kernel_fn, qr_uniform_q,
            no_projection: forwarded to FastAttention
            feature_redraw_interval: number of training-mode forward calls
                between projection redraws; None disables redrawing
            dropout: dropout probability applied to the output
        """
        super().__init__()
        assert dim % heads == 0, 'dimension must be divisible by number of heads'
        dim_head = dim // heads
        inner_dim = dim_head * heads
        if k_dim is None:
            k_dim = dim
        self.fast_attention = FastAttention(dim_head, nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
        self.heads = heads
        self.dim = dim
        self.to_query = nn.Linear(dim, inner_dim)
        self.to_key = nn.Linear(k_dim, inner_dim)
        self.to_value = nn.Linear(k_dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer("calls_since_last_redraw", torch.tensor(0))
        # above this many tokens (batch * length) the batch is processed in slices
        self.max_tokens = 2**16

    def check_redraw_projections(self):
        """In training mode, redraw the random projections once the call counter reaches the interval."""
        if not self.training:
            return
        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            device = get_module_device(self)
            fast_attentions = find_modules(self, FastAttention)
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix(device)
            self.calls_since_last_redraw.zero_()
            return
        self.calls_since_last_redraw += 1

    def _batched_forward(self, q, k, v):
        """Run the attention over slices of the batch dimension to bound memory use."""
        b1, h, n1 = q.shape[:3]
        out = torch.empty((b1, h, n1, self.dim//h), dtype=q.dtype, device=q.device)
        # batch items per slice; at least one, otherwise range() below would
        # receive a zero step (ValueError) whenever n1 exceeds max_tokens
        shift = max(1, self.max_tokens // n1)
        for i_b in range(0, b1, shift):
            start = i_b
            end = min(i_b+shift, b1)
            out[start:end] = self.fast_attention(q[start:end], k[start:end], v[start:end])
        return out

    def forward(self, query, key, value, **kwargs):
        """Attend from `query` to `key`/`value`.

        Args:
            query: tensor of shape (b1, n1, dim)
            key, value: tensors of shape (b2, n2, k_dim)
        Returns:
            tensor of shape (b1, n1, dim) after output projection and dropout
        """
        self.check_redraw_projections()
        b1, n1, _, h = *query.shape, self.heads
        b2, n2, _, h = *key.shape, self.heads
        q = self.to_query(query)
        k = self.to_key(key)
        v = self.to_value(value)
        q = q.reshape(b1, n1, h, -1).permute(0,2,1,3) # (b, h, n, d)
        k = k.reshape(b2, n2, h, -1).permute(0,2,1,3)
        v = v.reshape(b2, n2, h, -1).permute(0,2,1,3)
        if b1*n1 > self.max_tokens or b2*n2 > self.max_tokens:
            out = self._batched_forward(q, k, v)
        else:
            out = self.fast_attention(q, k, v)
        out = out.permute(0,2,1,3).reshape(b1,n1,-1)
        out = self.to_out(out)
        return self.dropout(out)

View file

@ -0,0 +1,317 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
import util
from collections import namedtuple
from ffindex import *
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
from trFold import TRFold
# Parent of the directory that contains this file (symlinks resolved)
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
# Size of the last axis of each of the four probability maps allocated in
# Predictor.predict (presumably dist, omega, theta, phi - confirm)
NBIN = [37, 37, 37, 19]
# Keyword arguments passed to RoseTTAFoldModule_e2e
MODEL_PARAM ={
    "n_module" : 8,
    "n_module_str" : 4,
    "n_module_ref" : 4,
    "n_layer" : 1,
    "d_msa" : 384 ,
    "d_pair" : 288,
    "d_templ" : 64,
    "n_head_msa" : 12,
    "n_head_pair" : 8,
    "n_head_templ" : 4,
    "d_hidden" : 64,
    "r_ff" : 4,
    "n_resblock" : 1,
    "p_drop" : 0.1,
    "use_templ" : True,
    "performer_N_opts": {"nb_features": 64},
    "performer_L_opts": {"nb_features": 64}
}
# Settings stored under MODEL_PARAM['SE3_param'] below
SE3_param = {
    "num_layers" : 2,
    "num_channels" : 16,
    "num_degrees" : 2,
    "l0_in_features": 32,
    "l0_out_features": 8,
    "l1_in_features": 3,
    "l1_out_features": 3,
    "num_edge_features": 32,
    "div": 2,
    "n_heads": 4
}
# Settings stored under MODEL_PARAM['REF_param'] below
REF_param = {
    "num_layers" : 3,
    "num_channels" : 32,
    "num_degrees" : 3,
    "l0_in_features": 32,
    "l0_out_features": 8,
    "l1_in_features": 3,
    "l1_out_features": 3,
    "num_edge_features": 32,
    "div": 4,
    "n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['REF_param'] = REF_param
# params for the folding protocol
fold_params = {
    # NOTE(review): SG7/SG9 look like 7- and 9-point smoothing filter
    # coefficients (each set sums to 1 after the division) - confirm
    "SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
    "SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
    "DCUT" : 19.5,
    "ALPHA" : 1.57,
    # TODO: add Cb to the motif
    # 3x3 coordinate table; presumably reference N, CA, C positions with CA
    # at the origin - confirm
    "NCAC" : np.array([[-0.676, -1.294, 0. ],
                       [ 0. , 0. , 0. ],
                       [ 1.5 , -0.174, 0. ]], dtype=np.float32),
    "CLASH" : 2.0,
    "PCUT" : 0.5,
    "DSTEP" : 0.5,
    "ASTEP" : np.deg2rad(10.0),
    "XYZRAD" : 7.5,
    "WANG" : 0.1,
    "WCST" : 0.1
}
# the 9-point coefficients are the ones used under the generic "SG" key
fold_params["SG"] = fold_params["SG9"]
class Predictor():
    """End-to-end RoseTTAFold predictor for multi-chain complexes.

    Builds RoseTTAFoldModule_e2e from MODEL_PARAM, loads the
    "<model_dir>/RoseTTAFold_e2e.pt" checkpoint and exposes `predict`, which
    writes "<out_prefix>.npz" (dist/omega/theta/phi maps) and
    "<out_prefix>.pdb" (backbone refined by TRFold).
    """

    # Chain letters handed out to consecutive subunits when writing PDB records.
    _CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    # Fixed-column PDB ATOM record: the four blanks after the residue number
    # cover iCode (col 27) and cols 28-30 so that x/y/z land in cols 31-54.
    _ATOM_FMT = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"

    def __init__(self, model_dir=None, use_cpu=False):
        """Create the network on the selected device and load its weights.

        model_dir: directory holding the checkpoint; defaults to
                   "<script_dir>/weights" when None.
        use_cpu:   force CPU even if CUDA is available.
        Exits the process with status 1 when the checkpoint file is missing.
        """
        if model_dir is None:
            self.model_dir = "%s/weights"%(script_dir)
        else:
            self.model_dir = model_dir
        #
        # define model name
        self.model_name = "RoseTTAFold"
        if torch.cuda.is_available() and (not use_cpu):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # softmax over the channel (bin) axis of the network logits
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)
        could_load = self.load_model(self.model_name)
        if not could_load:
            print ("ERROR: failed to load model")
            # non-zero status so calling pipelines can detect the failure
            sys.exit(1)

    def load_model(self, model_name, suffix='e2e'):
        """Load "<model_dir>/<model_name>_<suffix>.pt" into self.model.

        Returns False when the file does not exist, True after a successful
        strict state-dict load.
        """
        chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
        if not os.path.exists(chk_fn):
            return False
        checkpoint = torch.load(chk_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        return True

    @staticmethod
    def _crop_mask(L, start_1, end_1, start_2, end_2):
        """Boolean mask of length L that is True inside the two residue windows."""
        # builtin bool instead of the removed np.bool alias
        sel = np.zeros(L, dtype=bool)
        sel[start_1:end_1] = True
        sel[start_2:end_2] = True
        return sel

    @classmethod
    def _chain_and_resno(cls, i, Ls):
        """Map 0-based residue index i to (chain letter, 1-based residue number within that chain)."""
        resNo = i+1
        chain = "A"
        # walk the subunits from the last to the second; the first one whose
        # preceding total is exceeded by i+1 owns this residue
        for i_chain in range(len(Ls)-1,0,-1):
            tot_res = sum(Ls[:i_chain])
            if i+1 > tot_res:
                chain = cls._CHAIN_IDS[i_chain]
                resNo = i+1 - tot_res
                break
        return chain, resNo

    def predict(self, a3m_fn, out_prefix, Ls, templ_npz=None, window=1000, shift=100):
        """Predict a complex from an MSA and write the .npz and .pdb outputs.

        a3m_fn:     input MSA in a3m format (read with parse_a3m).
        out_prefix: prefix of the two output files.
        Ls:         lengths of the subunits, in sequence order.
        templ_npz:  optional npz file with "xyz_t", "t1d" and "t0d" template
                    arrays; NaN/zero placeholders are used when None.
        window, shift: sequences longer than 2*window are predicted as pairs
                    of windows of size `window` placed every `shift` residues,
                    and the per-crop outputs are averaged.
        """
        msa = parse_a3m(a3m_fn)
        N, L = msa.shape
        #
        if templ_npz is not None:
            templ = np.load(templ_npz)
            xyz_t = torch.from_numpy(templ["xyz_t"])
            t1d = torch.from_numpy(templ["t1d"])
            t0d = torch.from_numpy(templ["t0d"])
        else:
            xyz_t = torch.full((1, L, 3, 3), np.nan).float()
            t1d = torch.zeros((1, L, 3)).float()
            t0d = torch.zeros((1,3)).float()

        self.model.eval()
        with torch.no_grad():
            #
            msa = torch.tensor(msa).long().view(1, -1, L)
            idx_pdb_orig = torch.arange(L).long().view(1, L)
            idx_pdb = torch.arange(L).long().view(1, L)
            # put a gap of 500 in the residue index after every chain boundary
            L_prev = 0
            for L_i in Ls[:-1]:
                idx_pdb[:,L_prev+L_i:] += 500 # it was 200 originally.
                L_prev += L_i
            seq = msa[:,0]
            #
            # template features
            xyz_t = xyz_t.float().unsqueeze(0)
            t1d = t1d.float().unsqueeze(0)
            t0d = t0d.float().unsqueeze(0)
            t2d = xyz_to_t2d(xyz_t, t0d)
            #
            # do cropped prediction
            if L > window*2:
                prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
                count_1d = np.zeros((L,), dtype=np.float32)
                count_2d = np.zeros((L,L), dtype=np.float32)
                node_s = np.zeros((L,MODEL_PARAM['d_msa']), dtype=np.float32)
                #
                grids = np.arange(0, L-window+shift, shift)
                ngrids = grids.shape[0]
                print("ngrid:     ", ngrids)
                print("grids:     ", grids)
                print("windows:   ", window)

                for i in range(ngrids):
                    for j in range(i, ngrids):
                        start_1 = grids[i]
                        end_1 = min(grids[i]+window, L)
                        start_2 = grids[j]
                        end_2 = min(grids[j]+window, L)
                        sel = self._crop_mask(L, start_1, end_1, start_2, end_2)

                        input_msa = msa[:,:,sel]
                        mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
                        input_msa = input_msa[mask].unsqueeze(0)
                        input_msa = input_msa[:,:1000].to(self.device)
                        input_idx = idx_pdb[:,sel].to(self.device)
                        # un-gapped indices address the accumulation arrays
                        input_idx_orig = idx_pdb_orig[:,sel]
                        input_seq = input_msa[:,0].to(self.device)
                        #
                        # Select template
                        input_t1d = t1d[:,:,sel].to(self.device) # (B, T, L, 3)
                        input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
                        #
                        print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
                        with torch.cuda.amp.autocast():
                            logit_s, node, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d, return_raw=True)
                        #
                        # Not sure How can we merge init_crds.....
                        pred_lddt = torch.clamp(pred_lddt, 0.0, 1.0)
                        sub_idx = input_idx_orig[0]
                        sub_idx_2d = np.ix_(sub_idx, sub_idx)
                        count_2d[sub_idx_2d] += 1.0
                        count_1d[sub_idx] += 1.0
                        node_s[sub_idx] += node[0].cpu().numpy()
                        for i_logit, logit in enumerate(logit_s):
                            prob = self.active_fn(logit.float()) # calculate distogram
                            prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
                            prob_s[i_logit][sub_idx_2d] += prob
                        del logit_s, node
                #
                # average the accumulated crops
                for i in range(4):
                    prob_s[i] = prob_s[i] / count_2d[:,:,None]
                prob_in = np.concatenate(prob_s, axis=-1)
                node_s = node_s / count_1d[:, None]
                #
                node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
                seq = msa[:,0].to(self.device)
                idx_pdb = idx_pdb.to(self.device)
                prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
                with torch.cuda.amp.autocast():
                    xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
                print (lddt.mean())
            else:
                msa = msa[:,:1000].to(self.device)
                seq = msa[:,0]
                idx_pdb = idx_pdb.to(self.device)
                t1d = t1d[:,:10].to(self.device)
                t2d = t2d[:,:10].to(self.device)
                with torch.cuda.amp.autocast():
                    logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
                print (lddt.mean())
                prob_s = list()
                for logit in logit_s:
                    prob = self.active_fn(logit.float()) # distogram
                    prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
                    prob_s.append(prob)

        np.savez_compressed("%s.npz"%out_prefix, dist=prob_s[0].astype(np.float16), \
                            omega=prob_s[1].astype(np.float16),\
                            theta=prob_s[2].astype(np.float16),\
                            phi=prob_s[3].astype(np.float16))

        # run TRFold (outside no_grad: the folding step back-propagates)
        prob_trF = list()
        for prob in prob_s:
            prob = torch.tensor(prob).permute(2,0,1).to(self.device)
            prob += 1e-8
            prob = prob / torch.sum(prob, dim=0)[None]
            prob_trF.append(prob)
        xyz = xyz[0, :, 1]
        TRF = TRFold(prob_trF, fold_params)
        xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
        print (xyz.shape, lddt[0].shape, seq[0].shape)
        self.write_pdb(seq[0], xyz, Ls, Bfacts=lddt[0], prefix=out_prefix)

    def write_pdb(self, seq, atoms, Ls, Bfacts=None, prefix=None):
        """Write "<prefix>.pdb" with one chain per subunit.

        seq:    per-residue integer residue codes (indices into util.num2aa).
        atoms:  2-D input is written as CA only (one coordinate triple per
                residue); 3-D input with a second dimension of 3 is written
                as N, CA, C. Any other layout writes no records.
        Ls:     subunit lengths used to assign chain letters and restart
                residue numbering.
        Bfacts: per-residue values clamped to [0, 1] and written to the
                B-factor column; zeros when None.
        """
        L = len(seq)
        filename = "%s.pdb"%prefix
        ctr = 1
        with open(filename, 'wt') as f:
            if Bfacts is None:
                Bfacts = np.zeros(L)
            else:
                Bfacts = torch.clamp( Bfacts, 0, 1)

            for i,s in enumerate(seq):
                if (len(atoms.shape)==2):
                    chain, resNo = self._chain_and_resno(i, Ls)
                    f.write (self._ATOM_FMT%(
                            "ATOM", ctr, " CA ", util.num2aa[s],
                            chain, resNo, atoms[i,0], atoms[i,1], atoms[i,2],
                            1.0, Bfacts[i] ) )
                    ctr += 1
                elif atoms.shape[1]==3:
                    chain, resNo = self._chain_and_resno(i, Ls)
                    # atom names padded to the 4-column PDB name field
                    for j,atm_j in enumerate((" N  "," CA "," C  ")):
                        f.write (self._ATOM_FMT%(
                                "ATOM", ctr, atm_j, util.num2aa[s],
                                chain, resNo, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
                                1.0, Bfacts[i] ) )
                        ctr += 1
def get_args():
    """Parse the command line of the complex-prediction script.

    Returns the argparse namespace with model_dir, a3m_fn, out_prefix,
    Ls (list of int), templ_npz and use_cpu.
    """
    import argparse
    # RawTextHelpFormatter keeps the line breaks of the multi-line --templ_npz help
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-m", dest="model_dir", default="%s/weights"%(script_dir),
                        help="Path to pre-trained network weights [%s/weights]"%script_dir)
    parser.add_argument("-i", dest="a3m_fn", required=True,
                        help="Input multiple sequence alignments (in a3m format)")
    parser.add_argument("-o", dest="out_prefix", required=True,
                        help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb")
    parser.add_argument("-Ls", dest="Ls", required=True, nargs="+", type=int,
                        help="The length of each subunit (e.g. 220 400)")
    parser.add_argument("--templ_npz", default=None,
                        help='''npz file containing complex template information (xyz_t, t1d, t0d). If not provided, zero matrices will be given as templates
- xyz_t: N, CA, C coordinates of complex templates (T, L, 3, 3) For the unaligned region, it should be NaN
- t1d: 1-D features from HHsearch results (score, SS, probab column from atab file) (T, L, 3). For the unaligned region, it should be zeros
- t0d: 0-D features from HHsearch (Probability/100.0, Identities/100.0, Similarity from hhr file) (T, 3)''')
    parser.add_argument("--cpu", dest='use_cpu', default=False, action='store_true')
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    # Script entry point: run the prediction unless the distogram archive for
    # this output prefix is already on disk.
    args = get_args()
    npz_fn = "%s.npz"%args.out_prefix
    if not os.path.exists(npz_fn):
        pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
        pred.predict(args.a3m_fn, args.out_prefix, args.Ls, templ_npz=args.templ_npz)

View file

@ -0,0 +1,324 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule_e2e
import util
from collections import namedtuple
from ffindex import *
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
from trFold import TRFold
# Parent of the directory that contains this script (last path component of the
# script's real directory is dropped); used for the default weights/database paths.
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
# Channel counts of the four 2D output heads, in the order Predictor.predict
# saves them: dist, omega, theta, phi.
NBIN = [37, 37, 37, 19]
# Keyword arguments expanded into RoseTTAFoldModule_e2e(**MODEL_PARAM).
# "d_msa" is additionally used by Predictor.predict as the width of the
# accumulated per-residue node features.
# NOTE(review): the meaning of the remaining keys is defined by the model class,
# which is not visible from this file - confirm there before changing values.
MODEL_PARAM ={
"n_module" : 8,
"n_module_str" : 4,
"n_module_ref" : 4,
"n_layer" : 1,
"d_msa" : 384 ,
"d_pair" : 288,
"d_templ" : 64,
"n_head_msa" : 12,
"n_head_pair" : 8,
"n_head_templ" : 4,
"d_hidden" : 64,
"r_ff" : 4,
"n_resblock" : 1,
"p_drop" : 0.0,
"use_templ" : True,
"performer_N_opts": {"nb_features": 64},
"performer_L_opts": {"nb_features": 64}
}
# Settings attached below as MODEL_PARAM['SE3_param'].
SE3_param = {
"num_layers" : 2,
"num_channels" : 16,
"num_degrees" : 2,
"l0_in_features": 32,
"l0_out_features": 8,
"l1_in_features": 3,
"l1_out_features": 3,
"num_edge_features": 32,
"div": 2,
"n_heads": 4
}
# Settings attached below as MODEL_PARAM['REF_param'].
REF_param = {
"num_layers" : 3,
"num_channels" : 32,
"num_degrees" : 3,
"l0_in_features": 32,
"l0_out_features": 8,
"l1_in_features": 3,
"l1_out_features": 3,
"num_edge_features": 32,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['REF_param'] = REF_param
# params for the folding protocol
# This dict is handed to TRFold(prob_trF, fold_params) in Predictor.predict.
fold_params = {
# Two candidate smoothing kernels (7 and 9 taps); only the one copied to "SG"
# below is read by the folding code.
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
# Distance scale and exponent of the distogram background correction.
"DCUT" : 19.5,
"ALPHA" : 1.57,
# TODO: add Cb to the motif
# 3x3 backbone motif; rows are consumed as N, CA, C by the folding code.
"NCAC" : np.array([[-0.676, -1.294, 0. ],
[ 0. , 0. , 0. ],
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
# Multiplier of the linear ramp padded onto the short-distance side of the restraints.
"CLASH" : 2.0,
# Residue pairs are restrained only when their contact score exceeds this value.
"PCUT" : 0.5,
# Knot spacing handed to the spline fit for distances (DSTEP) and angles (ASTEP, radians).
"DSTEP" : 0.5,
"ASTEP" : np.deg2rad(10.0),
# NOTE(review): XYZRAD is not read by any code visible in this chunk - confirm it is still needed.
"XYZRAD" : 7.5,
# Weights of the angular (WANG) and peptide-geometry (WCST) loss terms.
"WANG" : 0.1,
"WCST" : 0.1
}
# Select the 9-tap kernel as the active smoothing filter.
fold_params["SG"] = fold_params["SG9"]
class Predictor():
    """End-to-end RoseTTAFold predictor for a single chain.

    Builds RoseTTAFoldModule_e2e from MODEL_PARAM; `predict` loads the
    "<model_dir>/RoseTTAFold_e2e.pt" checkpoint, writes "<out_prefix>.npz"
    (dist/omega/theta/phi maps), "<out_prefix>_init.pdb" (network
    coordinates) and "<out_prefix>.pdb" (TRFold-refined backbone with O added).
    """

    # Fixed-column PDB ATOM record: the four blanks after the residue number
    # cover iCode (col 27) and cols 28-30 so that x/y/z land in cols 31-54.
    _ATOM_FMT = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"

    def __init__(self, model_dir=None, use_cpu=False):
        """Create the network on the selected device (weights are loaded later, in predict).

        model_dir: directory holding the checkpoint; defaults to the "models"
                   folder next to this file when None.
        use_cpu:   force CPU even if CUDA is available.
        """
        if model_dir is None:
            self.model_dir = "%s/models"%(os.path.dirname(os.path.realpath(__file__)))
        else:
            self.model_dir = model_dir
        #
        # define model name
        self.model_name = "RoseTTAFold"
        if torch.cuda.is_available() and (not use_cpu):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # softmax over the channel (bin) axis of the network logits
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule_e2e(**MODEL_PARAM).to(self.device)

    def load_model(self, model_name, suffix='e2e'):
        """Load "<model_dir>/<model_name>_<suffix>.pt" into self.model.

        Returns False when the file does not exist, True after a successful
        strict state-dict load.
        """
        chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
        if not os.path.exists(chk_fn):
            return False
        checkpoint = torch.load(chk_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        return True

    @staticmethod
    def _crop_mask(L, start_1, end_1, start_2, end_2):
        """Boolean mask of length L that is True inside the two residue windows."""
        # builtin bool instead of the removed np.bool alias
        sel = np.zeros(L, dtype=bool)
        sel[start_1:end_1] = True
        sel[start_2:end_2] = True
        return sel

    def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=150, shift=75):
        """Predict a structure from an MSA and write the .npz and .pdb outputs.

        a3m_fn:     input MSA in a3m format (read with parse_a3m).
        out_prefix: prefix of the output files.
        hhr_fn, atab_fn: optional HHsearch outputs; when hhr_fn is given the
                    templates are read with read_templates, which uses the
                    module-level `ffdb` created in the __main__ block.
                    NaN/zero placeholders are used when hhr_fn is None.
        window, shift: sequences longer than 2*window are predicted as pairs
                    of windows of size `window` placed every `shift` residues,
                    and the per-crop outputs are averaged.
        Exits the process with status 1 when the checkpoint file is missing.
        """
        msa = parse_a3m(a3m_fn)
        N, L = msa.shape
        #
        if hhr_fn is not None:
            xyz_t, t1d, t0d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=10)
        else:
            xyz_t = torch.full((1, L, 3, 3), np.nan).float()
            t1d = torch.zeros((1, L, 3)).float()
            t0d = torch.zeros((1,3)).float()
        #
        msa = torch.tensor(msa).long().view(1, -1, L)
        idx_pdb = torch.arange(L).long().view(1, L)
        seq = msa[:,0]
        #
        # template features
        xyz_t = xyz_t.float().unsqueeze(0)
        t1d = t1d.float().unsqueeze(0)
        t0d = t0d.float().unsqueeze(0)
        t2d = xyz_to_t2d(xyz_t, t0d)

        could_load = self.load_model(self.model_name, suffix="e2e")
        if not could_load:
            print ("ERROR: failed to load model")
            # non-zero status so calling pipelines can detect the failure
            sys.exit(1)

        self.model.eval()
        with torch.no_grad():
            # do cropped prediction if protein is too big
            if L > window*2:
                prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
                count_1d = np.zeros((L,), dtype=np.float32)
                count_2d = np.zeros((L,L), dtype=np.float32)
                node_s = np.zeros((L,MODEL_PARAM['d_msa']), dtype=np.float32)
                #
                grids = np.arange(0, L-window+shift, shift)
                ngrids = grids.shape[0]
                print("ngrid:     ", ngrids)
                print("grids:     ", grids)
                print("windows:   ", window)

                for i in range(ngrids):
                    for j in range(i, ngrids):
                        start_1 = grids[i]
                        end_1 = min(grids[i]+window, L)
                        start_2 = grids[j]
                        end_2 = min(grids[j]+window, L)
                        sel = self._crop_mask(L, start_1, end_1, start_2, end_2)

                        input_msa = msa[:,:,sel]
                        mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
                        input_msa = input_msa[mask].unsqueeze(0)
                        input_msa = input_msa[:,:1000].to(self.device)
                        input_idx = idx_pdb[:,sel].to(self.device)
                        input_seq = input_msa[:,0].to(self.device)
                        #
                        # Select template
                        input_t1d = t1d[:,:,sel].to(self.device) # (B, T, L, 3)
                        input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
                        #
                        print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
                        with torch.cuda.amp.autocast():
                            logit_s, node, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d, return_raw=True)
                        #
                        # Not sure How can we merge init_crds.....
                        sub_idx = input_idx[0].cpu()
                        sub_idx_2d = np.ix_(sub_idx, sub_idx)
                        count_2d[sub_idx_2d] += 1.0
                        count_1d[sub_idx] += 1.0
                        node_s[sub_idx] += node[0].cpu().numpy()
                        for i_logit, logit in enumerate(logit_s):
                            prob = self.active_fn(logit.float()) # calculate distogram
                            prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
                            prob_s[i_logit][sub_idx_2d] += prob
                        del logit_s, node
                #
                # combine all crops
                for i in range(4):
                    prob_s[i] = prob_s[i] / count_2d[:,:,None]
                prob_in = np.concatenate(prob_s, axis=-1)
                node_s = node_s / count_1d[:, None]
                #
                # Do iterative refinement using SE(3)-Transformers
                # clear cache memory
                torch.cuda.empty_cache()
                #
                node_s = torch.tensor(node_s).to(self.device).unsqueeze(0)
                seq = msa[:,0].to(self.device)
                idx_pdb = idx_pdb.to(self.device)
                prob_in = torch.tensor(prob_in).to(self.device).unsqueeze(0)
                with torch.cuda.amp.autocast():
                    xyz, lddt = self.model(node_s, seq, idx_pdb, prob_s=prob_in, refine_only=True)
            else:
                msa = msa[:,:1000].to(self.device)
                seq = msa[:,0]
                idx_pdb = idx_pdb.to(self.device)
                t1d = t1d[:,:10].to(self.device)
                t2d = t2d[:,:10].to(self.device)
                with torch.cuda.amp.autocast():
                    logit_s, _, xyz, lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
                prob_s = list()
                for logit in logit_s:
                    prob = self.active_fn(logit.float()) # distogram
                    prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
                    prob_s.append(prob)

        np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
                            omega=prob_s[1].astype(np.float16),\
                            theta=prob_s[2].astype(np.float16),\
                            phi=prob_s[3].astype(np.float16))

        self.write_pdb(seq[0], xyz[0], idx_pdb[0], Bfacts=lddt[0], prefix="%s_init"%(out_prefix))

        # run TRFold (outside no_grad: the folding step back-propagates)
        prob_trF = list()
        for prob in prob_s:
            prob = torch.tensor(prob).permute(2,0,1).to(self.device)
            prob += 1e-8
            prob = prob / torch.sum(prob, dim=0)[None]
            prob_trF.append(prob)
        xyz = xyz[0, :, 1]
        TRF = TRFold(prob_trF, fold_params)
        xyz = TRF.fold(xyz, batch=15, lr=0.1, nsteps=200)
        xyz = xyz.detach().cpu().numpy()
        # add O and Cb
        N = xyz[:,0,:]
        CA = xyz[:,1,:]
        C = xyz[:,2,:]
        # O is built from the N of the following residue (np.roll wraps the last one around)
        O = self.extend(np.roll(N, -1, axis=0), CA, C, 1.231, 2.108, -3.142)
        xyz = np.concatenate((xyz, O[:,None,:]), axis=1)
        self.write_pdb(seq[0], xyz, idx_pdb[0], Bfacts=lddt[0], prefix=out_prefix)

    def extend(self, a,b,c, L,A,D):
        '''
        input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
        output: 4th coord
        '''
        def unit(x):
            # normalise along the last axis; 1e-8 guards against zero-length vectors
            return x/np.sqrt(np.square(x).sum(-1,keepdims=True) + 1e-8)
        bc = unit(b-c)
        n = unit(np.cross(b-a, bc))
        # local orthogonal frame at c and the new atom's components in it
        basis = [bc,np.cross(n,bc),n]
        comps = [L*np.cos(A), L*np.sin(A)*np.cos(D), -L*np.sin(A)*np.sin(D)]
        return c + sum([axis*comp for axis,comp in zip(basis,comps)])

    def write_pdb(self, seq, atoms, idx, Bfacts=None, prefix=None):
        """Write "<prefix>.pdb" as a single chain "A".

        seq:    per-residue integer residue codes (indices into util.num2aa).
        atoms:  2-D input is written as CA only; 3-D input with a second
                dimension of 3 is written as N, CA, C and of 4 as N, CA, C, O.
                Any other layout writes no records.
        idx:    0-based residue indices; idx[i]+1 becomes the residue number.
        Bfacts: per-residue values clamped to [0, 1] and written to the
                B-factor column; zeros when None.
        """
        L = len(seq)
        filename = "%s.pdb"%prefix
        ctr = 1
        # atom names padded to the 4-column PDB name field
        if len(atoms.shape)==2:
            names = None
        elif atoms.shape[1]==3:
            names = (" N  "," CA "," C  ")
        elif atoms.shape[1]==4:
            names = (" N  "," CA "," C  ", " O  ")
        else:
            names = ()
        with open(filename, 'wt') as f:
            if Bfacts is None:
                Bfacts = np.zeros(L)
            else:
                Bfacts = torch.clamp( Bfacts, 0, 1)

            for i,s in enumerate(seq):
                if names is None:
                    f.write (self._ATOM_FMT%(
                            "ATOM", ctr, " CA ", util.num2aa[s],
                            "A", idx[i]+1, atoms[i,0], atoms[i,1], atoms[i,2],
                            1.0, Bfacts[i] ) )
                    ctr += 1
                else:
                    for j,atm_j in enumerate(names):
                        f.write (self._ATOM_FMT%(
                                "ATOM", ctr, atm_j, util.num2aa[s],
                                "A", idx[i]+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
                                1.0, Bfacts[i] ) )
                        ctr += 1
def get_args():
    """Parse the command line of the end-to-end prediction script.

    Returns the argparse namespace with model_dir, a3m_fn, out_prefix, hhr,
    atab, db and use_cpu.
    """
    import argparse
    # (flags, keyword arguments) in the order they appear in --help
    option_table = [
        (("-m",), dict(dest="model_dir", default="%s/weights"%(script_dir),
                       help="Path to pre-trained network weights [%s/weights]"%script_dir)),
        (("-i",), dict(dest="a3m_fn", required=True,
                       help="Input multiple sequence alignments (in a3m format)")),
        (("-o",), dict(dest="out_prefix", required=True,
                       help="Prefix for output file. The output files will be [out_prefix].npz and [out_prefix].pdb")),
        (("--hhr",), dict(default=None,
                          help="HHsearch output file (hhr file). If not provided, zero matrices will be given as templates")),
        (("--atab",), dict(default=None,
                           help="HHsearch output file (atab file)")),
        (("--db",), dict(default="%s/pdb100_2021Mar03/pdb100_2021Mar03"%script_dir,
                         help="Path to template database [%s/pdb100_2021Mar03]"%script_dir)),
        (("--cpu",), dict(dest='use_cpu', default=False, action='store_true')),
    ]
    cli = argparse.ArgumentParser()
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point. The template database is opened first and published
    # as the module-level `ffdb`, which Predictor.predict reads when an hhr
    # file is supplied.
    args = get_args()
    FFDB=args.db
    FFindexDB = namedtuple("FFindexDB", "index, data")
    ffdb_index = read_index(FFDB+'_pdb.ffindex')
    ffdb_data = read_data(FFDB+'_pdb.ffdata')
    ffdb = FFindexDB(ffdb_index, ffdb_data)

    # skip the run when the distogram archive already exists
    npz_fn = "%s.npz"%args.out_prefix
    if not os.path.exists(npz_fn):
        pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
        pred.predict(args.a3m_fn, args.out_prefix, args.hhr, args.atab)

View file

@ -0,0 +1,200 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
from ffindex import *
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d
# Parent of the directory that contains this script (last path component of the
# script's real directory is dropped); used for the default weights/database paths.
script_dir = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-1])
# Channel counts of the four 2D output heads, in the order Predictor.predict
# saves them: dist, omega, theta, phi.
NBIN = [37, 37, 37, 19]
# Keyword arguments expanded into RoseTTAFoldModule(**MODEL_PARAM).
# NOTE(review): the meaning of individual keys is defined by the model class,
# which is not visible from this file - confirm there before changing values.
MODEL_PARAM ={
"n_module" : 8,
"n_module_str" : 4,
"n_layer" : 1,
"d_msa" : 384 ,
"d_pair" : 288,
"d_templ" : 64,
"n_head_msa" : 12,
"n_head_pair" : 8,
"n_head_templ" : 4,
"d_hidden" : 64,
"r_ff" : 4,
"n_resblock" : 1,
"p_drop" : 0.1,
"use_templ" : True,
"performer_N_opts": {"nb_features": 64},
"performer_L_opts": {"nb_features": 64}
}
# Settings attached below as MODEL_PARAM['SE3_param'].
SE3_param = {
"num_layers" : 2,
"num_channels" : 16,
"num_degrees" : 2,
"l0_in_features": 32,
"l0_out_features": 8,
"l1_in_features": 3,
"l1_out_features": 3,
"num_edge_features": 32,
"div": 2,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
class Predictor():
    """RoseTTAFold predictor that only produces the 2D maps.

    Builds RoseTTAFoldModule from MODEL_PARAM, loads the
    "<model_dir>/RoseTTAFold_pyrosetta.pt" checkpoint and exposes `predict`,
    which writes "<out_prefix>.npz" with dist/omega/theta/phi maps.
    """

    def __init__(self, model_dir=None, use_cpu=False):
        """Create the network on the selected device and load its weights.

        model_dir: directory holding the checkpoint; defaults to the "models"
                   folder next to this file when None.
        use_cpu:   force CPU even if CUDA is available.
        Exits the process with status 1 when the checkpoint file is missing.
        """
        if model_dir is None:
            self.model_dir = "%s/models"%(os.path.dirname(os.path.realpath(__file__)))
        else:
            self.model_dir = model_dir
        #
        # define model name
        self.model_name = "RoseTTAFold"
        if torch.cuda.is_available() and (not use_cpu):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # softmax over the channel (bin) axis of the network logits
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule(**MODEL_PARAM).to(self.device)
        could_load = self.load_model(self.model_name)
        if not could_load:
            print ("ERROR: failed to load model")
            # non-zero status so calling pipelines can detect the failure
            sys.exit(1)

    def load_model(self, model_name, suffix='pyrosetta'):
        """Load "<model_dir>/<model_name>_<suffix>.pt" into self.model.

        Returns False when the file does not exist, True after the load.
        The state dict is loaded with strict=False, so missing or unexpected
        keys in the checkpoint are tolerated.
        """
        chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
        if not os.path.exists(chk_fn):
            return False
        checkpoint = torch.load(chk_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        return True

    @staticmethod
    def _crop_mask(L, start_1, end_1, start_2, end_2):
        """Boolean mask of length L that is True inside the two residue windows."""
        # builtin bool instead of the removed np.bool alias
        sel = np.zeros(L, dtype=bool)
        sel[start_1:end_1] = True
        sel[start_2:end_2] = True
        return sel

    def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=150, shift=50):
        """Predict the 2D maps from an MSA and write "<out_prefix>.npz".

        a3m_fn:     input MSA in a3m format (read with parse_a3m).
        out_prefix: prefix of the output file.
        hhr_fn, atab_fn: optional HHsearch outputs; when hhr_fn is given the
                    templates are read with read_templates, which uses the
                    module-level `ffdb` created in the __main__ block.
                    NaN/zero placeholders are used when hhr_fn is None.
        window, shift: sequences longer than 2*window are predicted as pairs
                    of windows of size `window` placed every `shift` residues;
                    crops are combined as an average weighted by the clamped
                    per-residue pred_lddt returned by the model.
        """
        msa = parse_a3m(a3m_fn)
        N, L = msa.shape
        #
        if hhr_fn is not None:
            xyz_t, t1d, t0d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=25)
        else:
            xyz_t = torch.full((1, L, 3, 3), np.nan).float()
            t1d = torch.zeros((1, L, 3)).float()
            t0d = torch.zeros((1,3)).float()

        self.model.eval()
        with torch.no_grad():
            #
            msa = torch.tensor(msa).long().view(1, -1, L)
            idx_pdb = torch.arange(L).long().view(1, L)
            seq = msa[:,0]
            #
            # template features
            xyz_t = xyz_t.float().unsqueeze(0)
            t1d = t1d.float().unsqueeze(0)
            t0d = t0d.float().unsqueeze(0)
            t2d = xyz_to_t2d(xyz_t, t0d)
            #
            # do cropped prediction
            if L > window*2:
                prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
                count = np.zeros((L,L), dtype=np.float32)
                #
                grids = np.arange(0, L-window+shift, shift)
                ngrids = grids.shape[0]
                print("ngrid:     ", ngrids)
                print("grids:     ", grids)
                print("windows:   ", window)

                for i in range(ngrids):
                    for j in range(i, ngrids):
                        start_1 = grids[i]
                        end_1 = min(grids[i]+window, L)
                        start_2 = grids[j]
                        end_2 = min(grids[j]+window, L)
                        sel = self._crop_mask(L, start_1, end_1, start_2, end_2)

                        input_msa = msa[:,:,sel]
                        mask = torch.sum(input_msa==20, dim=-1) < 0.5*sel.sum() # remove too gappy sequences
                        input_msa = input_msa[mask].unsqueeze(0)
                        input_msa = input_msa[:,:1000].to(self.device)
                        input_idx = idx_pdb[:,sel].to(self.device)
                        input_seq = input_msa[:,0].to(self.device)
                        #
                        input_t1d = t1d[:,:,sel].to(self.device)
                        input_t2d = t2d[:,:,sel][:,:,:,sel].to(self.device)
                        #
                        print ("running crop: %d-%d/%d-%d"%(start_1, end_1, start_2, end_2), input_msa.shape)
                        with torch.cuda.amp.autocast():
                            logit_s, init_crds, pred_lddt = self.model(input_msa, input_seq, input_idx, t1d=input_t1d, t2d=input_t2d)
                        #
                        # pair weight = sum of the two residues' clamped lDDT
                        pred_lddt = torch.clamp(pred_lddt, 0.0, 1.0)
                        weight = pred_lddt[0][:,None] + pred_lddt[0][None,:]
                        weight = weight.cpu().numpy() + 1e-8
                        sub_idx = input_idx[0].cpu()
                        sub_idx_2d = np.ix_(sub_idx, sub_idx)
                        count[sub_idx_2d] += weight
                        for i_logit, logit in enumerate(logit_s):
                            prob = self.active_fn(logit.float()) # calculate distogram
                            prob = prob.squeeze(0).permute(1,2,0).cpu().numpy()
                            prob_s[i_logit][sub_idx_2d] += weight[:,:,None]*prob
                # normalise by the accumulated weights
                for i in range(4):
                    prob_s[i] = prob_s[i] / count[:,:,None]
            else:
                msa = msa[:,:1000].to(self.device)
                seq = msa[:,0]
                idx_pdb = idx_pdb.to(self.device)
                t1d = t1d.to(self.device)
                t2d = t2d.to(self.device)
                logit_s, init_crds, pred_lddt = self.model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d)
                prob_s = list()
                for logit in logit_s:
                    prob = self.active_fn(logit.float()) # distogram
                    prob = prob.reshape(-1, L, L).permute(1,2,0).cpu().numpy()
                    prob_s.append(prob)

        np.savez_compressed("%s.npz"%out_prefix, dist=prob_s[0].astype(np.float16), \
                            omega=prob_s[1].astype(np.float16),\
                            theta=prob_s[2].astype(np.float16),\
                            phi=prob_s[3].astype(np.float16))
def get_args():
    """Parse the command line of the 2D-map prediction script.

    Returns the argparse namespace with model_dir, a3m_fn, out_prefix, hhr,
    atab, db and use_cpu.
    """
    import argparse
    # (flags, keyword arguments) in the order they appear in --help
    option_table = [
        (("-m",), dict(dest="model_dir", default="%s/weights"%(script_dir),
                       help="Path to pre-trained network weights [%s/weights]"%script_dir)),
        (("-i",), dict(dest="a3m_fn", required=True,
                       help="Input multiple sequence alignments (in a3m format)")),
        (("-o",), dict(dest="out_prefix", required=True,
                       help="Prefix for output file. The output file will be [out_prefix].npz")),
        (("--hhr",), dict(default=None,
                          help="HHsearch output file (hhr file). If not provided, zero matrices will be given as templates")),
        (("--atab",), dict(default=None,
                           help="HHsearch output file (atab file)")),
        (("--db",), dict(default="%s/pdb100_2021Mar03/pdb100_2021Mar03"%script_dir,
                         help="Path to template database [%s/pdb100_2021Mar03]"%script_dir)),
        (("--cpu",), dict(dest='use_cpu', default=False, action='store_true')),
    ]
    cli = argparse.ArgumentParser()
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point. The template database is opened first and published
    # as the module-level `ffdb`, which Predictor.predict reads when an hhr
    # file is supplied.
    args = get_args()
    FFDB=args.db
    FFindexDB = namedtuple("FFindexDB", "index, data")
    ffdb_index = read_index(FFDB+'_pdb.ffindex')
    ffdb_data = read_data(FFDB+'_pdb.ffdata')
    ffdb = FFindexDB(ffdb_index, ffdb_data)

    # skip the run when the output archive already exists
    npz_fn = "%s.npz"%args.out_prefix
    if not os.path.exists(npz_fn):
        pred = Predictor(model_dir=args.model_dir, use_cpu=args.use_cpu)
        pred.predict(args.a3m_fn, args.out_prefix, args.hhr, args.atab)

View file

@ -0,0 +1,95 @@
import torch
import torch.nn as nn
# original resblock
class ResBlock2D(nn.Module):
    """Residual block: conv -> instance norm -> ELU -> dropout -> conv -> instance norm,
    added to the input and passed through a final ELU. Channel count and spatial
    size are preserved."""

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super(ResBlock2D, self).__init__()
        pad = self._get_same_padding(kernel, dilation)

        # the position of each module inside the Sequential fixes its state_dict key
        self.layer = nn.Sequential(
            nn.Conv2d(n_c, n_c, kernel, padding=pad, dilation=dilation, bias=False),
            nn.InstanceNorm2d(n_c, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # dropout
            nn.Dropout(p_drop),
            # convolution
            nn.Conv2d(n_c, n_c, kernel, dilation=dilation, padding=pad, bias=False),
            nn.InstanceNorm2d(n_c, affine=True, eps=1e-6),
        )
        self.final_activation = nn.ELU(inplace=True)

    def _get_same_padding(self, kernel, dilation):
        # span of the dilated kernel, then half of (span - 1)
        span = kernel + (kernel - 1) * (dilation - 1)
        return (span - 1) // 2

    def forward(self, x):
        residual = self.layer(x)
        summed = x + residual
        return self.final_activation(summed)
# pre-activation bottleneck resblock
class ResBlock2D_bottleneck(nn.Module):
    """Pre-activation bottleneck residual block: normalise/activate, project to
    half the channels with a 1x1 conv, apply the dilated conv, project back with
    a 1x1 conv and add the result to the input (no activation after the sum)."""

    def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
        super(ResBlock2D_bottleneck, self).__init__()
        pad = self._get_same_padding(kernel, dilation)
        half = n_c // 2 # bottleneck channel

        # the position of each module inside the Sequential fixes its state_dict key
        self.layer = nn.Sequential(
            # pre-activation
            nn.InstanceNorm2d(n_c, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # project down to the bottleneck width
            nn.Conv2d(n_c, half, 1, bias=False),
            nn.InstanceNorm2d(half, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # convolution
            nn.Conv2d(half, half, kernel, dilation=dilation, padding=pad, bias=False),
            nn.InstanceNorm2d(half, affine=True, eps=1e-6),
            nn.ELU(inplace=True),
            # dropout
            nn.Dropout(p_drop),
            # project up
            nn.Conv2d(half, n_c, 1, bias=False),
        )

    def _get_same_padding(self, kernel, dilation):
        # span of the dilated kernel, then half of (span - 1)
        span = kernel + (kernel - 1) * (dilation - 1)
        return (span - 1) // 2

    def forward(self, x):
        residual = self.layer(x)
        return x + residual
class ResidualNetwork(nn.Module):
    """Stack of 2D residual blocks with optional 1x1 projections in and out.

    n_block residual blocks of width n_feat_block are applied, cycling through
    the `dilation` list. block_type 'orig' selects ResBlock2D, anything else
    selects ResBlock2D_bottleneck.
    """

    def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
                 dilation=[1,2,4,8], block_type='orig', p_drop=0.15):
        super(ResidualNetwork, self).__init__()

        stages = []
        # project to n_feat_block
        if n_feat_in != n_feat_block:
            stages.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
            if block_type =='orig': # should activate input
                stages += [nn.InstanceNorm2d(n_feat_block, affine=True, eps=1e-6),
                           nn.ELU(inplace=True)]

        # add resblocks
        use_orig = block_type == 'orig'
        for k in range(n_block):
            rate = dilation[k%len(dilation)]
            if use_orig:
                stages.append(ResBlock2D(n_feat_block, kernel=3, dilation=rate, p_drop=p_drop))
            else:
                stages.append(ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=rate, p_drop=p_drop))

        if n_feat_out != n_feat_block:
            # project to n_feat_out
            stages.append(nn.Conv2d(n_feat_block, n_feat_out, 1))

        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)

View file

@ -0,0 +1,252 @@
import numpy as np
import torch
import torch.nn as nn
def perturb_init(xyz, batch, noise=0.5):
    """Return `batch` randomly shifted copies of xyz, shape (batch, L, 3).

    The offsets come from np.random.uniform(noise, size=...): `noise` is the
    LOW bound and the high bound stays at numpy's default of 1.0, so every
    offset lies in [noise, 1.0). The offsets are float64, which also promotes
    the result to float64 for float32 input.
    NOTE(review): a symmetric range around zero may have been intended - confirm.
    """
    n_res = xyz.shape[0]
    offsets = np.random.uniform(noise, size=(batch, n_res, 3))
    jitter = torch.tensor(offsets, device=xyz.device).detach()
    return xyz.unsqueeze(0) + jitter
def Q2R(Q):
    '''Convert quaternions (w, x, y, z) of shape [b, l, 4] into rotation
    matrices of shape [b, l, 3, 3]. The input is used as given; callers are
    expected to pass unit quaternions.'''
    n_batch, n_res, _ = Q.shape
    qw, qx, qy, qz = Q[...,0], Q[...,1], Q[...,2], Q[...,3]

    # pairwise products shared between matrix entries
    xx, yy, zz = qx*qx, qy*qy, qz*qz
    xy, xz, yz = qx*qy, qx*qz, qy*qz
    xw, yw, zw = qx*qw, qy*qw, qz*qw

    row0 = [1-2*yy-2*zz, 2*xy-2*zw, 2*xz+2*yw]
    row1 = [2*xy+2*zw, 1-2*xx-2*zz, 2*yz-2*xw]
    row2 = [2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy]

    flat = torch.stack(row0 + row1 + row2, dim=-1)
    return flat.view(n_batch, n_res, 3, 3)
def get_cb(N,Ca,C):
    """Place Cb from the backbone atoms N, Ca, C.

    Cb is Ca plus a fixed linear combination of the two bond vectors
    (Ca-N, C-Ca) and their cross product; the last axis of all inputs holds
    the xyz coordinates.
    """
    n_to_ca = Ca - N
    ca_to_c = C - Ca
    normal = torch.cross(n_to_ca, ca_to_c, dim=-1)
    return -0.58273431*normal + 0.56802827*n_to_ca - 0.54067466*ca_to_c + Ca
# ============================================================
def get_ang(a, b, c):
    """calculate planar angles for all consecutive triples (a[i],b[i],c[i])
    from Cartesian coordinates of three sets of atoms a,b,c

    Parameters
    ----------
    a,b,c : pytorch tensors of shape [batch,nres,3]
            store Cartesian coordinates of three sets of atoms
    Returns
    -------
    ang : pytorch tensor of shape [batch,nres]
          stores resulting planar angles (radians, vertex at b)
    """
    v = a - b
    w = c - b
    v = v / torch.norm(v, dim=-1, keepdim=True)
    w = w / torch.norm(w, dim=-1, keepdim=True)
    vw = torch.sum(v*w, dim=-1)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where acos returns NaN; clamp so the angle stays finite.
    vw = torch.clamp(vw, -1.0, 1.0)
    return torch.acos(vw)
# ============================================================
def get_dih(a, b, c, d):
    """Dihedral angle of every quadruple (a[i], b[i], c[i], d[i]).

    Parameters
    ----------
    a,b,c,d : pytorch tensors of shape [batch,nres,3]
              Cartesian coordinates of the four atom sets
    Returns
    -------
    dih : pytorch tensor of shape [batch,nres]
          dihedral about the b-c axis, from atan2 (radians, in [-pi, pi])
    """
    arm_ab = a - b
    axis = c - b
    arm_cd = d - c

    axis = axis / torch.norm(axis, dim=-1, keepdim=True)

    # components of the two arms perpendicular to the b-c axis
    perp_ab = arm_ab - torch.sum(arm_ab*axis, dim=-1, keepdim=True)*axis
    perp_cd = arm_cd - torch.sum(arm_cd*axis, dim=-1, keepdim=True)*axis

    cos_part = torch.sum(perp_ab*perp_cd, dim=-1)
    sin_part = torch.sum(torch.cross(axis,perp_ab,dim=-1)*perp_cd, dim=-1)
    return torch.atan2(sin_part, cos_part)
class TRFold():
    """Gradient-descent folding of a backbone against predicted 2D restraints.

    pred:   sequence of four probability tensors (dist, omega, theta, phi),
            each indexed [bin, i, j].
    params: dict of folding settings; the keys read here are SG, DCUT, ALPHA,
            CLASH, NCAC, PCUT, DSTEP, ASTEP, WANG and WCST.
    """

    def __init__(self, pred, params):
        self.pred = pred
        self.params = params
        self.device = self.pred[0].device

        # dfire background correction for distograms
        self.bkgd = (torch.linspace(4.25,19.75,32,device=self.device)/
                     self.params['DCUT'])**self.params['ALPHA']

        # background correction for phi
        ang = torch.linspace(0,np.pi,19,device=self.device)[:-1]
        self.bkgp = 0.5*(torch.cos(ang)-torch.cos(ang+np.deg2rad(10.0)))

        # Sav-Gol filter
        self.sg = torch.from_numpy(self.params['SG']).float().to(self.device)

        # paddings for distograms:
        # left - linear clash; right - zeroes
        padRsize = self.sg.shape[-1]//2+3
        padLsize = padRsize + 8
        padR = torch.zeros(padRsize,device=self.device)
        padL = torch.arange(1,padLsize+1,device=self.device).flip(0)*self.params['CLASH']
        self.padR = padR[:,None]
        self.padL = padL[:,None]

        # backbone motif
        self.ncac = torch.from_numpy(self.params['NCAC']).to(self.device)

    def akima(self, y,h):
        ''' Akima spline coefficients (boundaries trimmed to [2:-2])
        https://doi.org/10.1145/321607.321609

        y: [n_curves, n_knots] values sampled every h.
        Returns [n_curves, n_knots-5, 4] cubic coefficients (c0..c3) per
        interval. Knot slopes whose weights are both zero (0/0) are set to 0.'''
        m = (y[:,1:]-y[:,:-1])/h
        #m += 1e-3*torch.randn(m.shape, device=m.device)
        m4m3 = torch.abs(m[:,3:]-m[:,2:-1])
        m2m1 = torch.abs(m[:,1:-2]-m[:,:-3])
        t = (m4m3*m[:,1:-2] + m2m1*m[:,2:-1])/(m4m3+m2m1)
        t[torch.isnan(t)] = 0.0
        dy = y[:,3:-2]-y[:,2:-3]
        coef = torch.stack([y[:,2:-3],
                            t[:,:-1],
                            (3*dy/h - 2*t[:,:-1] - t[:,1:])/h,
                            (t[:,:-1]+t[:,1:] - 2*dy/h)/h**2], dim=-1)
        return coef

    def fold(self, xyz, batch=32, lr=0.8, nsteps=100):
        """Optimise `batch` perturbed copies of the input and return the best one.

        xyz:    [L, 3] starting coordinates (one point per residue).
        batch:  number of independently perturbed trajectories.
        lr:     Adam learning rate.
        nsteps: number of optimisation steps (must be >= 1).
        Returns the [L, 3, 3] N/CA/C coordinates of the trajectory with the
        lowest loss, as evaluated at the start of the last step.
        """
        if nsteps < 1:
            # with no step there would be no backbone to return
            raise ValueError("nsteps must be >= 1")

        pd,po,pt,pp = self.pred
        L = pd.shape[-1]

        # per-pair score built from the last bin of every map; pairs above PCUT get restraints
        p20 = (6.0-pd[-1]-po[-1]-pt[-1]-pp[-1]-(pt[-1]+pp[-1]).T)/6
        i,j = torch.triu_indices(L,L,1,device=self.device)
        sel = torch.where(p20[i,j]>self.params['PCUT'])[0]

        # indices for dist and omega (symmetric)
        i_s,j_s = i[sel], j[sel]

        # indices for theta and phi (asymmetric)
        i_a,j_a = torch.hstack([i_s,j_s]), torch.hstack([j_s,i_s])

        # background-corrected initial restraints
        cstd = -torch.log(pd[4:36,i_s,j_s]/self.bkgd[:,None])
        csto = -torch.log(po[0:36,i_s,j_s]/(1./36)) # omega and theta
        cstt = -torch.log(pt[0:36,i_a,j_a]/(1./36)) # are almost uniform
        cstp = -torch.log(pp[0:18,i_a,j_a]/self.bkgp[:,None])

        # padded restraints
        pad = self.sg.shape[-1]//2+3
        cstd = torch.cat([self.padL+cstd[0],cstd,self.padR+cstd[-1]],dim=0)
        csto = torch.cat([csto[-pad:],csto,csto[:pad]],dim=0)
        cstt = torch.cat([cstt[-pad:],cstt,cstt[:pad]],dim=0)
        cstp = torch.cat([cstp[:pad].flip(0),cstp,cstp[-pad:].flip(0)],dim=0)

        # smoothed restraints
        cstd,csto,cstt,cstp = [nn.functional.conv1d(cst.T.unsqueeze(1),self.sg)[:,0]
                               for cst in [cstd,csto,cstt,cstp]]

        # force distance restraints vanish at long distances
        cstd = cstd-cstd[:,-1][:,None]

        # akima spline coefficients
        coefd = self.akima(cstd, self.params['DSTEP']).detach()
        coefo = self.akima(csto, self.params['ASTEP']).detach()
        coeft = self.akima(cstt, self.params['ASTEP']).detach()
        coefp = self.akima(cstp, self.params['ASTEP']).detach()

        # the bin lookup below uses the same step sizes the splines were built with
        dstep = self.params['DSTEP']
        astep = self.params['ASTEP']

        ko = torch.arange(i_s.shape[0],device=self.device).repeat(batch)
        kt = torch.arange(i_a.shape[0],device=self.device).repeat(batch)

        # initial Ca placement using EDM+minimization
        xyz = perturb_init(xyz, batch) # (batch, L, 3)

        # optimization variables: T - shift vectors, Q - rotation quaternions
        T = torch.zeros_like(xyz,device=self.device,requires_grad=True)
        Q = torch.randn([batch,L,4],device=self.device,requires_grad=True)
        bb0 = self.ncac[None,:,None,:].repeat(batch,1,L,1)

        opt = torch.optim.Adam([T,Q], lr=lr)
        for step in range(nsteps):
            R = Q2R(Q/torch.norm(Q,dim=-1,keepdim=True))
            bb = torch.einsum("blij,bklj->bkli",R,bb0)+(xyz+T)[:,None]

            # TODO: include Cb in the motif
            N,Ca,C = bb[:,0],bb[:,1],bb[:,2]
            Cb = get_cb(N,Ca,C)

            o = get_dih(Ca[:,i_s],Cb[:,i_s],Cb[:,j_s],Ca[:,j_s]) + np.pi
            t = get_dih(N[:,i_a],Ca[:,i_a],Cb[:,i_a],Cb[:,j_a]) + np.pi
            p = get_ang(Ca[:,i_a],Cb[:,i_a],Cb[:,j_a])

            # distance term: only pairs currently closer than 20.0 contribute
            dij = torch.norm(Cb[:,i_s]-Cb[:,j_s],dim=-1)
            b,k = torch.where(dij<20.0)
            dk = dij[b,k]
            kbin = torch.ceil((dk-dstep/2)/dstep).long()
            dx = (dk-dstep/2)%dstep
            c = coefd[k,kbin]
            lossd = c[:,0]+c[:,1]*dx+c[:,2]*dx**2+c[:,3]*dx**3

            # omega
            obin = torch.ceil((o.view(-1)-astep/2)/astep).long()
            do = (o.view(-1)-astep/2)%astep
            co = coefo[ko,obin]
            losso = (co[:,0]+co[:,1]*do+co[:,2]*do**2+co[:,3]*do**3).view(batch,-1) #.mean(1)

            # theta
            tbin = torch.ceil((t.view(-1)-astep/2)/astep).long()
            dt = (t.view(-1)-astep/2)%astep
            ct = coeft[kt,tbin]
            losst = (ct[:,0]+ct[:,1]*dt+ct[:,2]*dt**2+ct[:,3]*dt**3).view(batch,-1) #.mean(1)

            # phi
            pbin = torch.ceil((p.view(-1)-astep/2)/astep).long()
            dp = (p.view(-1)-astep/2)%astep
            cp = coefp[kt,pbin]
            lossp = (cp[:,0]+cp[:,1]*dp+cp[:,2]*dp**2+cp[:,3]*dp**3).view(batch,-1) #.mean(1)

            # restrain geometry of peptide bonds
            loss_nc = (torch.norm(C[:,:-1]-N[:,1:],dim=-1)-1.32868)**2
            loss_cacn = (get_ang(Ca[:,:-1], C[:,:-1], N[:,1:]) - 2.02807)**2
            loss_canc = (get_ang(Ca[:,1:], N[:,1:], C[:,:-1]) - 2.12407)**2

            loss_geom = loss_nc.mean(1) + loss_cacn.mean(1) + loss_canc.mean(1)
            loss_ang = losst.mean(1) + losso.mean(1) + lossp.mean(1)

            # coefficient for ramping up geometric restraints during minimization
            coef = (1.0+step)/nsteps

            loss = lossd.mean() + self.params['WANG']*loss_ang.mean() + coef*self.params['WCST']*loss_geom.mean()

            opt.zero_grad()
            loss.backward()
            opt.step()

        # rank the trajectories with the per-trajectory losses of the last step
        lossd = torch.stack([lossd[b==i].mean() for i in range(batch)])
        loss = lossd + self.params['WANG']*loss_ang + self.params['WCST']*loss_geom
        minidx = torch.argmin(loss)

        return bb[minidx].permute(1,0,2)

View file

@ -0,0 +1,253 @@
import sys
import numpy as np
import torch
# Three-letter names of the 20 residue types handled by this module.
# A residue's integer type index is its position in this list.
num2aa=[
'ALA','ARG','ASN','ASP','CYS',
'GLN','GLU','GLY','HIS','ILE',
'LEU','LYS','MET','PHE','PRO',
'SER','THR','TRP','TYR','VAL',
]
# Inverse lookup: three-letter residue name -> index into num2aa.
aa2num= {x:i for i,x in enumerate(num2aa)}
# minimal sc atom representation (Nx8)
# One tuple per residue type, in the same order as num2aa (see the trailing
# comment on each row). Every tuple has 8 slots holding atom-name strings;
# None marks a slot that is unused for that residue type.
# The first three slots are always N, CA and C.
aa2short=[
(" N "," CA "," C "," CB ", None, None, None, None), # ala
(" N "," CA "," C "," CB "," CG "," CD "," NE "," CZ "), # arg
(" N "," CA "," C "," CB "," CG "," OD1", None, None), # asn
(" N "," CA "," C "," CB "," CG "," OD1", None, None), # asp
(" N "," CA "," C "," CB "," SG ", None, None, None), # cys
(" N "," CA "," C "," CB "," CG "," CD "," OE1", None), # gln
(" N "," CA "," C "," CB "," CG "," CD "," OE1", None), # glu
(" N "," CA "," C ", None, None, None, None, None), # gly
(" N "," CA "," C "," CB "," CG "," ND1", None, None), # his
(" N "," CA "," C "," CB "," CG1"," CD1", None, None), # ile
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # leu
(" N "," CA "," C "," CB "," CG "," CD "," CE "," NZ "), # lys
(" N "," CA "," C "," CB "," CG "," SD "," CE ", None), # met
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # phe
(" N "," CA "," C "," CB "," CG "," CD ", None, None), # pro
(" N "," CA "," C "," CB "," OG ", None, None, None), # ser
(" N "," CA "," C "," CB "," OG1", None, None, None), # thr
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # trp
(" N "," CA "," C "," CB "," CG "," CD1", None, None), # tyr
(" N "," CA "," C "," CB "," CG1", None, None, None), # val
]
# full sc atom representation (Nx14)
# One tuple per residue type, in num2aa order. Every tuple has 14 slots of
# atom-name strings; None marks an unused slot. The first four slots are
# always N, CA, C and O, followed by the side-chain atoms.
# writepdb() iterates these tuples to decide which atoms to emit per residue.
aa2long=[
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None), # arg
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
(" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None), # asp
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None), # glu
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None), # leu
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None), # phe
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None), # tyr
(" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None), # val
]
# build the "alternate" sc mapping
# Same layout as aa2long (one 14-slot tuple per residue type, num2aa order),
# but with the order of some atom pairs swapped: NH1/NH2 (arg), OD1/OD2 (asp),
# OE1/OE2 (glu), CD1/CD2 (leu), CD1/CD2 and CE1/CE2 (phe, tyr), CG1/CG2 (val).
# All other rows are identical to aa2long.
# NOTE(review): presumably the swapped pairs are chemically equivalent atoms
# whose naming is interchangeable -- confirm against the code that uses long2alt.
aa2longalt=[
(" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # ala
(" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH2"," NH1", None, None, None), # arg
(" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None), # asn
(" N "," CA "," C "," O "," CB "," CG "," OD2"," OD1", None, None, None, None, None, None), # asp
(" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None), # cys
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None), # gln
(" N "," CA "," C "," O "," CB "," CG "," CD "," OE2"," OE1", None, None, None, None, None), # glu
(" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
(" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None), # his
(" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None), # ile
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1", None, None, None, None, None, None), # leu
(" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None), # lys
(" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None), # met
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ ", None, None, None), # phe
(" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None), # pro
(" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None), # ser
(" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None), # thr
(" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE2"," CE3"," NE1"," CZ2"," CZ3"," CH2"), # trp
(" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ "," OH ", None, None), # tyr
(" N "," CA "," C "," O "," CB "," CG2"," CG1", None, None, None, None, None, None, None), # val
]
# build "deterministic" atoms
# (see the notebook se3_experiments.ipynb for the derivation)
# One list per residue type, in num2aa order. Each entry describes one atom
# that is absent from aa2short and has to be built from atoms that are present:
#   [atom to build, base atom, parent atom, grandparent atom, [x, y, z]]
# where [x, y, z] are the atom's coordinates in the local frame spanned by
# base/parent/grandparent (consumed by the short2long table below and
# ultimately by atoms_from_frames()).
# An empty list means no atom of that residue type needs to be built this way.
aa2frames=[
[], # ala
[ # arg
[' NH1', ' CZ ', ' NE ', ' CD ', [-0.7218378782272339, 1.0856682062149048, -0.006118079647421837]],
[' NH2', ' CZ ', ' NE ', ' CD ', [-0.6158039569854736, -1.1400136947631836, 0.006467342376708984]]],
[ # asn
[' ND2', ' CG ', ' CB ', ' OD1', [-0.6304131746292114, -1.1431225538253784, 0.02364802360534668]]],
[ # asp
[' OD2', ' CG ', ' CB ', ' OD1', [-0.5972501039505005, -1.0955055952072144, 0.04530305415391922]]],
[], # cys
[ # gln
[' NE2', ' CD ', ' CG ', ' OE1', [-0.6558755040168762, -1.1324536800384521, 0.026521772146224976]]],
[ # glu
[' OE2', ' CD ', ' CG ', ' OE1', [-0.5578438639640808, -1.1161314249038696, -0.015464287251234055]]],
[], # gly
[ # his
[' CD2', ' CG ', ' CB ', ' ND1', [-0.7502505779266357, -1.1680538654327393, 0.0005368441343307495]],
[' CE1', ' CG ', ' CB ', ' ND1', [-2.0262467861175537, 0.539483368396759, -0.004495501518249512]],
[' NE2', ' CG ', ' CB ', ' ND1', [-2.0761325359344482, -0.8199722766876221, -0.0018703639507293701]]],
[ # ile
[' CG2', ' CB ', ' CA ', ' CG1', [-0.6059935688972473, -0.8108057379722595, 1.1861376762390137]]],
[ # leu
[' CD2', ' CG ', ' CB ', ' CD1', [-0.5942193269729614, -0.7693282961845398, -1.1914138793945312]]],
[], # lys
[], # met
[ # phe
[' CD2', ' CG ', ' CB ', ' CD1', [-0.7164441347122192, -1.197853446006775, 0.06416648626327515]],
[' CE1', ' CG ', ' CB ', ' CD1', [-2.0785865783691406, 1.2366485595703125, 0.08100450038909912]],
[' CE2', ' CG ', ' CB ', ' CD1', [-2.107091188430786, -1.178497076034546, 0.13524535298347473]],
[' CZ ', ' CG ', ' CB ', ' CD1', [-2.786630630493164, 0.03873880207538605, 0.14633776247501373]]],
[], # pro
[], # ser
[ # thr
[' CG2', ' CB ', ' CA ', ' OG1', [-0.6842088103294373, -0.6709619164466858, 1.2105456590652466]]],
[ # trp
[' CD2', ' CG ', ' CB ', ' CD1', [-0.8550368547439575, -1.0790592432022095, 0.09017711877822876]],
[' NE1', ' CG ', ' CB ', ' CD1', [-2.1863200664520264, 0.8064242601394653, 0.08350661396980286]],
[' CE2', ' CG ', ' CB ', ' CD1', [-2.1801204681396484, -0.5795643329620361, 0.14015203714370728]],
[' CE3', ' CG ', ' CB ', ' CD1', [-0.605582594871521, -2.4733362197875977, 0.16200461983680725]],
# NOTE(review): exact duplicate of the ' CE2' entry two rows up. The lookup in
# the short2long construction uses list.index(), which returns the first
# match, so this row is never selected.
[' CE2', ' CG ', ' CB ', ' CD1', [-2.1801204681396484, -0.5795643329620361, 0.14015203714370728]],
[' CZ2', ' CG ', ' CB ', ' CD1', [-3.2672977447509766, -1.473116159439087, 0.250858873128891]],
[' CZ3', ' CG ', ' CB ', ' CD1', [-1.6969941854476929, -3.3360071182250977, 0.264143705368042]],
[' CH2', ' CG ', ' CB ', ' CD1', [-3.009331703186035, -2.8451972007751465, 0.3059283494949341]]],
[ # tyr
[' CD2', ' CG ', ' CB ', ' CD1', [-0.69439297914505, -1.2123756408691406, -0.009198814630508423]],
[' CE1', ' CG ', ' CB ', ' CD1', [-2.104464054107666, 1.1910505294799805, -0.014679580926895142]],
[' CE2', ' CG ', ' CB ', ' CD1', [-2.0857787132263184, -1.2231677770614624, -0.024517983198165894]],
[' CZ ', ' CG ', ' CB ', ' CD1', [-2.7897322177886963, -0.021470561623573303, -0.026979409158229828]],
[' OH ', ' CG ', ' CB ', ' CD1', [-4.1559271812438965, -0.029129385948181152, -0.044720835983753204]]],
[ # val
[' CG2', ' CB ', ' CA ', ' CG1', [-0.6258467435836792, -0.7654698491096497, -1.1894742250442505]]],
]
# O from frame (C,N-1,CA)
# Local-frame [x, y, z] coordinates used to place the backbone O atom; copied
# into the x/y/z columns of short2long for every ' O ' slot.
bb2oframe=[-0.5992066264152527, -1.0820008516311646, 0.0001476481556892395]
# Lookup table that maps every slot of the full (aa2long) representation to
# indices in the reduced (aa2short) representation.
# Shape 20 x 14 x 6 = <base-idx | parent-idx | gparent-idx | x | y | z>
#   base-idx < 0  ==> the slot holds no atom
#   xyz == 0      ==> the atom is taken directly from the reduced set
short2long = np.zeros((20,14,6))
for i in range(20):
    i_s, i_l = aa2short[i], aa2long[i]
    for j, a in enumerate(i_l):
        if a is None:
            # case 1: no atom defined for this slot
            short2long[i,j,0] = -1
        elif a in i_s:
            # case 2: the atom exists in the reduced set; parent/grandparent
            # are the two lowest reduced-set indices other than the atom itself
            base_idx = i_s.index(a)
            parent_idx, gparent_idx = {0: (1, 2), 1: (0, 2)}.get(base_idx, (0, 1))
            short2long[i,j,:3] = (base_idx, parent_idx, gparent_idx)
        elif a == " O ":
            # case 3: backbone oxygen, built from C (base), N (parent) and CA
            # parent index 0 is Nprev (N will be pre-rolled as nothing else needs it)
            short2long[i,j,:3] = (2, 0, 1)
            short2long[i,j,3:] = np.array(bb2oframe)
        else:
            # case 4: side-chain atom that has to be built from its frame entry
            i_f = aa2frames[i]
            idx = [f[0] for f in i_f].index(a)
            _, base_name, parent_name, gparent_name, local_xyz = i_f[idx]
            short2long[i,j,:3] = [i_s.index(n) for n in (base_name, parent_name, gparent_name)]
            short2long[i,j,3:] = np.array(local_xyz)
# Slot-index map from the full representation (aa2long) to the "alternate"
# one (aa2longalt): long2alt[i, j] is the slot in aa2longalt[i] that holds the
# atom found at aa2long[i][j]. Empty slots map to themselves.
long2alt = np.zeros((20,14))
for i, (i_l, i_lalt) in enumerate(zip(aa2long, aa2longalt)):
    for j, a in enumerate(i_l):
        long2alt[i,j] = j if a is None else i_lalt.index(a)
def atoms_from_frames(base,parent,gparent,points):
    """Convert local-frame coordinates into global coordinates.

    For every row an orthonormal frame is built with its origin at ``base``:
    the x axis points from ``base`` to ``parent``, the y axis is the part of
    ``gparent - base`` orthogonal to x, and z = x cross y.

    :param base: (n, 3) tensor, frame origins
    :param parent: (n, 3) tensor, points defining the x axes
    :param gparent: (n, 3) tensor, points defining the xy planes
    :param points: (n, 3) tensor, coordinates expressed in the local frames
    :return: (n, 3) tensor, ``points`` expressed in global coordinates
    """
    xs = parent-base
    # handle parent==base: fall back to the global x axis to avoid 0/0
    mask = (torch.sum(torch.square(xs),dim=-1)==0)
    xs[mask,0] = 1.0
    xs = xs / torch.norm(xs, dim=-1)[:,None]
    ys = gparent-base
    # Gram-Schmidt: remove the component of ys along xs
    ys = ys - torch.sum(xs*ys,dim=-1)[:,None]*xs
    # handle gparent==base (or gparent on the x axis): fall back to global y
    mask = (torch.sum(torch.square(ys),dim=-1)==0)
    ys[mask,1] = 1.0
    ys = ys / torch.norm(ys, dim=-1)[:,None]
    # dim must be explicit: without it torch.cross uses the FIRST dimension of
    # size 3, which is the row dimension whenever n == 3
    zs = torch.cross(xs,ys,dim=-1)
    # columns of q are the frame axes
    q = torch.stack((xs,ys,zs),dim=2)
    retval = base + torch.einsum('nij,nj->ni',q,points)
    return retval
#def atoms_from_frames(base,parent,gparent,points):
# xs = parent-base
# # handle parent=base
# mask = (torch.sum(torch.square(xs), dim=-1) == 0)
# xs[mask,0] = 1.0
# xs = xs / torch.norm(xs, dim=-1)[:,None]
#
# ys = gparent-base
# # handle gparent=base
# mask = (torch.sum(torch.square(ys),dim=-1)==0)
# ys[mask,1] = 1.0
#
# ys = ys - torch.sum(xs*ys,dim=-1)[:,None]*xs
# ys = ys / torch.norm(ys, dim=-1)[:,None]
# zs = torch.cross(xs,ys)
# q = torch.stack((xs,ys,zs),dim=2)
# #return base + q@points
# return base + torch.einsum('nij,nj->ni',q,points)
# writepdb
def writepdb(filename, atoms, bfacts, seq):
    """Write a structure to ``filename`` as PDB ATOM records.

    :param filename: path of the file to create (overwritten if it exists)
    :param atoms: tensor indexed as [residue, atom slot, xyz]; the atom slots
        follow aa2long for the residue's type
    :param bfacts: per-residue values written to the B-factor column after
        being clamped to [0, 1]
    :param seq: per-residue integer residue types (indices into num2aa)

    All atoms are written as chain "A" with occupancy 1.0; residues are
    numbered from 1 and atom serial numbers count only atoms actually written.
    """
    scpu = seq.cpu()
    atomscpu = atoms.cpu()
    Bfacts = torch.clamp( bfacts.cpu(), 0, 1)
    ctr = 1
    # context manager guarantees the file is flushed and closed, also on error
    with open(filename,"w") as f:
        for i,s in enumerate(scpu):
            atms = aa2long[s]
            for j,atm_j in enumerate(atms):
                if (atm_j is not None):
                    # The PDB format is column based: after the residue number
                    # (columns 23-26) four blanks are needed so that x/y/z land
                    # in columns 31-54.
                    f.write ("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
                        "ATOM", ctr, atm_j, num2aa[s],
                        "A", i+1, atomscpu[i,j,0], atomscpu[i,j,1], atomscpu[i,j,2],
                        1.0, Bfacts[i] ) )
                    ctr += 1

View file

@ -0,0 +1,64 @@
import warnings
import dgl
import torch
def to_np(x):
    """Return tensor ``x`` as a NumPy array, moved to the CPU and detached from autograd."""
    host_tensor = x.cpu()
    return host_tensor.detach().numpy()
class PickleGraph:
    """Lightweight graph object for easy pickling. Does not support batched graphs."""

    def __init__(self, G=None, desired_keys=None):
        """Snapshot graph ``G`` into plain NumPy arrays.

        :param G: graph to copy (must offer ``batch_size``, ``all_edges()``,
            ``ndata`` and ``edata``), or None for an empty graph
        :param desired_keys: collection of node/edge data keys to keep;
            None keeps every key
        """
        self.ndata = {}
        self.edata = {}
        if G is None:
            self.src, self.dst = [], []
            return
        if G.batch_size > 1:
            warnings.warn("Copying a batched graph to a PickleGraph is not supported. "
                          "All node and edge data will be copied, but batching information will be lost.")
        self.src, self.dst = map(to_np, G.all_edges())

        def wanted(key):
            return desired_keys is None or key in desired_keys

        self.ndata.update((k, to_np(G.ndata[k])) for k in G.ndata if wanted(k))
        self.edata.update((k, to_np(G.edata[k])) for k in G.edata if wanted(k))

    def all_edges(self):
        """Return the (source ids, destination ids) pair of the stored edges."""
        return self.src, self.dst
def copy_dgl_graph(G):
    """Return a copy of DGL graph ``G`` with cloned node and edge data.

    A batched graph is split into its components, each component is copied
    on its own, and the copies are batched again.
    """
    if G.batch_size != 1:
        return dgl.batch([copy_dgl_graph(part) for part in dgl.unbatch(G)])

    src, dst = G.all_edges()
    # NOTE(review): the node count is inferred from the edge list here;
    # presumably graphs with trailing isolated nodes are not expected -- confirm.
    duplicate = dgl.DGLGraph((src, dst))
    for edge_key in list(G.edata.keys()):
        duplicate.edata[edge_key] = torch.clone(G.edata[edge_key])
    for node_key in list(G.ndata.keys()):
        duplicate.ndata[node_key] = torch.clone(G.ndata[node_key])
    return duplicate
def update_relative_positions(G, *, relative_position_key='d', absolute_position_key='x'):
    """Store, for every directed edge, the offset from its source node to its destination node.

    Node positions are read from ``G.ndata[absolute_position_key]``; the
    per-edge result (destination minus source) is written in place to
    ``G.edata[relative_position_key]``. Nothing is returned.
    """
    source_ids, destination_ids = G.all_edges()
    positions = G.ndata[absolute_position_key]
    destination_positions = positions[destination_ids]
    source_positions = positions[source_ids]
    G.edata[relative_position_key] = destination_positions - source_positions

View file

@ -0,0 +1,123 @@
import os
import sys
import time
import datetime
import subprocess
import numpy as np
import torch
from utils.utils_data import to_np
# Module-wide store: variable name -> list of gradient-norm arrays appended by
# the hooks that log_gradient_norm() registers.
_global_log = {}
def try_mkdir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    :param path: directory to create
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Something already exists at `path`. Attempting the creation and
        # tolerating this error avoids the race of a separate existence check
        # when several processes create the same directory.
        pass
# @profile
def make_logdir(checkpoint_dir, run_name=None):
    """Create the log directory of a run below ``checkpoint_dir`` and return its path.

    :param checkpoint_dir: parent directory for all runs
    :param run_name: name of the sub-directory; when None, the current local
        time formatted as ``%Y_%m_%d_%H.%M.%S`` is used instead
    :return: path of the log directory
    """
    if run_name is not None:
        assert type(run_name) == str
        subdir = run_name
    else:
        subdir = datetime.datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
    log_dir = os.path.join(checkpoint_dir, subdir)
    try_mkdir(log_dir)
    return log_dir
def count_parameters(model):
    """
    count number of trainable parameters in module
    :param model: nn.Module instance
    :return: integer
    """
    # numel() is an exact Python int for every shape, including 0-dim
    # parameters (np.prod of an empty size would yield the float 1.0).
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def write_info_file(model, FLAGS, UNPARSED_ARGV, wandb_log_dir=None):
    """Prepare the log directory of a run and record how the run was started.

    Writes ``info_<time>.txt`` (git revision, all FLAGS, parameter counts,
    unparsed arguments and the command line) and ``git_diff_<time>.txt``
    (output of ``git diff``) into the log directory.

    :param model: model whose trainable parameters are counted (also its
        ``enc`` / ``dec`` attributes when present)
    :param FLAGS: parsed arguments; reads ``restore``, ``run_name`` and
        ``checkpoint_dir``
    :param UNPARSED_ARGV: arguments the parser did not consume
    :param wandb_log_dir: when given, used as the log directory
    :return: tuple ``(log_dir, checkpoint_path)``
    :raises subprocess.CalledProcessError: if ``git describe`` fails
    """
    time_str = time.strftime("%m%d_%H%M%S")
    filename_log = "info_" + time_str + ".txt"
    filename_git_diff = "git_diff_" + time_str + ".txt"
    checkpoint_name = 'model'
    if wandb_log_dir:
        log_dir = wandb_log_dir
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
    elif FLAGS.restore:
        # set restore path
        assert FLAGS.run_name is not None
        log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.run_name)
    else:
        # makes logdir with time stamp
        log_dir = make_logdir(FLAGS.checkpoint_dir, FLAGS.run_name)
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
        os.mkdir(os.path.join(log_dir, 'point_clouds'))
    checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)

    # writing arguments and git hash to info file; the context manager closes
    # the file even if one of the steps below raises
    with open(os.path.join(log_dir, filename_log), "w") as info_file:
        label = subprocess.check_output(["git", "describe", "--always"]).strip()
        # check_output returns bytes: decode it so the file contains the
        # revision itself rather than its "b'...'" representation
        info_file.write('latest git commit on this branch: ' + label.decode(errors='replace') + '\n')
        info_file.write('\nFLAGS: \n')
        flag_values = vars(FLAGS)
        for key in sorted(flag_values):
            info_file.write(key + ': ' + str(flag_values[key]) + '\n')
        # count number of parameters
        if hasattr(model, 'parameters'):
            info_file.write('\nNumber of Model Parameters: ' + str(count_parameters(model)) + '\n')
            if hasattr(model, 'enc'):
                info_file.write('\nNumber of Encoder Parameters: ' + str(
                    count_parameters(model.enc)) + '\n')
            if hasattr(model, 'dec'):
                info_file.write('\nNumber of Decoder Parameters: ' + str(
                    count_parameters(model.dec)) + '\n')
        info_file.write('\nUNPARSED_ARGV:\n' + str(UNPARSED_ARGV))
        info_file.write('\n\nBASH COMMAND: \n')
        info_file.write(' '.join(['python'] + list(sys.argv)))

    # write 'git diff' output into extra file; redirecting in Python instead
    # of through a shell string keeps paths with spaces or metacharacters safe
    with open(os.path.join(log_dir, filename_git_diff), "w") as diff_file:
        subprocess.call(["git", "diff"], stdout=diff_file)
    return log_dir, checkpoint_path
def log_gradient_norm(tensor, variable_name):
    """Record the gradient norms of ``tensor`` under ``variable_name``.

    Registers a hook on ``tensor``; each time the hook receives a gradient,
    its norm over the last dimension is appended (as a NumPy array) to
    ``_global_log[variable_name]``.
    """
    _global_log.setdefault(variable_name, [])

    def log_gradient_norm_inner(gradient):
        norms = torch.norm(gradient, dim=-1)
        # look the list up at call time so that clear_data() takes effect
        _global_log[variable_name].append(to_np(norms))

    tensor.register_hook(log_gradient_norm_inner)
def get_average(variable_name):
    """Return the mean of all values logged under ``variable_name``.

    :return: NaN if the name was never registered, 0 if nothing has been
        logged for it yet, otherwise the mean over all logged arrays
    """
    try:
        records = _global_log[variable_name]
    except KeyError:
        return float('nan')
    if not records:
        return 0
    return np.mean(np.concatenate(records))
def clear_data(variable_name):
    """Discard everything logged under ``variable_name`` (the name stays registered)."""
    _global_log[variable_name] = list()

View file

@ -0,0 +1,5 @@
# Make the name `profile` usable as a decorator in every run: if it is not
# already defined (NameError), fall back to a decorator that returns the
# function unchanged.
# NOTE(review): presumably `profile` is injected as a builtin by an external
# line profiler when the code runs under it -- confirm.
try:
    profile
except NameError:
    def profile(func):
        return func

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,42 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import py3Dmol
def execute_pipeline(sequence):
    """Run ``run_inference_pipeline.sh`` on a FASTA file or on a raw sequence.

    :param sequence: path to an existing FASTA file, or the sequence itself;
        a raw sequence is upper-cased and written to ``temp_input.fa`` in the
        current working directory first
    """
    # local import: only needed to quote the user-supplied path for the shell
    import shlex

    if os.path.isfile(sequence):
        with open(sequence, "r") as f:
            # header line without its first two characters (presumably "> ");
            # readline() avoids loading the whole file and tolerates an empty one
            title = f.readline()[2:]
        print(f"Running inference on {title}")
        # quote the path so that spaces or shell metacharacters in it cannot
        # split the argument or inject commands
        os.system(f"bash run_inference_pipeline.sh {shlex.quote(sequence)}")
    else:
        try:
            with open("temp_input.fa", "w") as f:
                f.write(f"> {sequence[:8]}...\n")
                f.write(sequence.upper())
            # the file is closed (and flushed) before the pipeline reads it
            print(f"Running inference on {sequence[:8]}...")
            os.system("bash run_inference_pipeline.sh temp_input.fa")
        except:
            print("Unable to run the pipeline.")
            raise
def display_pdb(path_to_pdb):
    """Show the structure stored in ``path_to_pdb`` in a 400x300 py3Dmol viewer.

    The model is drawn as a cartoon coloured with the 'spectrum' scheme.
    """
    with open(path_to_pdb) as pdb_file:
        pdb_text = pdb_file.read()
    viewer = py3Dmol.view(width=400, height=300)
    viewer.addModelsAsFrames(pdb_text)
    viewer.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
    viewer.zoomTo()
    viewer.show()
def cleanup():
    """Remove all files matching ``t000*`` from the current working directory."""
    remove_command = "rm t000*"
    os.system(remove_command)

View file

@ -0,0 +1,4 @@
pandas==1.1.4
scikit-learn==0.24
packaging==21.0
py3Dmol==1.7.0

View file

@ -0,0 +1,78 @@
#!/bin/bash
# End-to-end structure prediction pipeline.
# Usage: <this script> input.fasta working_folder
# Each step is skipped when its (non-empty) output file already exists in the
# working folder, so an interrupted run can be resumed.

# make the script stop when an error (non-zero exit code) occurs
set -e
############################################################
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('conda' 'shell.bash' 'hook' 2> /dev/null)"
eval "$__conda_setup"
unset __conda_setup
# <<< conda initialize <<<
############################################################
# All path expansions below are quoted so that directories containing spaces
# or glob characters work.
SCRIPT=$(realpath -s "$0")
export PIPEDIR=$(dirname "$SCRIPT")
CPU="8" # number of CPUs to use
MEM="64" # max memory (in GB)
# Inputs:
IN="$1" # input.fasta
WDIR=$(realpath -s "$2") # working folder
LEN=$(tail -n1 "$IN" | wc -m)
mkdir -p "$WDIR/log"
conda activate RoseTTAFold
############################################################
# 1. generate MSAs
############################################################
if [ ! -s "$WDIR/t000_.msa0.a3m" ]
then
    echo "Running HHblits"
    "$PIPEDIR/input_prep/make_msa.sh" "$IN" "$WDIR" "$CPU" "$MEM" > "$WDIR/log/make_msa.stdout" 2> "$WDIR/log/make_msa.stderr"
fi
############################################################
# 2. predict secondary structure for HHsearch run
############################################################
if [ ! -s "$WDIR/t000_.ss2" ]
then
    echo "Running PSIPRED"
    "$PIPEDIR/input_prep/make_ss.sh" "$WDIR/t000_.msa0.a3m" "$WDIR/t000_.ss2" > "$WDIR/log/make_ss.stdout" 2> "$WDIR/log/make_ss.stderr"
fi
############################################################
# 3. search for templates
############################################################
DB="$PIPEDIR/pdb100_2021Mar03/pdb100_2021Mar03"
if [ ! -s "$WDIR/t000_.hhr" ]
then
    echo "Running hhsearch"
    # keep the command in an array so that every argument stays one word
    HH=(hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu "$CPU" -maxmem "$MEM" -aliw 100000 -e 100 -p 5.0 -d "$DB")
    cat "$WDIR/t000_.ss2" "$WDIR/t000_.msa0.a3m" > "$WDIR/t000_.msa0.ss2.a3m"
    "${HH[@]}" -i "$WDIR/t000_.msa0.ss2.a3m" -o "$WDIR/t000_.hhr" -atab "$WDIR/t000_.atab" -v 0 > "$WDIR/log/hhsearch.stdout" 2> "$WDIR/log/hhsearch.stderr"
fi
############################################################
# 4. end-to-end prediction
############################################################
# NOTE(review): the guard tests t000_.3track.npz while the step writes with the
# prefix t000_.e2e -- confirm which file predict_e2e.py actually produces.
if [ ! -s "$WDIR/t000_.3track.npz" ]
then
    echo "Running end-to-end prediction"
    python "$PIPEDIR/network/predict_e2e.py" \
        -m "$PIPEDIR/weights" \
        -i "$WDIR/t000_.msa0.a3m" \
        -o "$WDIR/t000_.e2e" \
        --hhr "$WDIR/t000_.hhr" \
        --atab "$WDIR/t000_.atab" \
        --db "$DB" 1> "$WDIR/log/network.stdout" 2> "$WDIR/log/network.stderr"
fi
echo "Done"

View file

@ -0,0 +1,17 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
from pipeline_utils import execute_pipeline
# Command-line interface: one positional argument that is either a raw
# sequence or the path to a FASTA file.
PARSER = argparse.ArgumentParser()
PARSER.add_argument('sequence', type=str, help='A sequence or a path to a FASTA file')
if __name__ == '__main__':
    # Parse the command line and hand the argument to the pipeline wrapper.
    args = PARSER.parse_args()
    execute_pipeline(args.sequence)

View file

@ -0,0 +1,76 @@
#!/bin/bash
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Inference pipeline for the container image.
# Usage: <this script> input.fasta [databases_dir] [weights_dir] [working_dir]
# Each preparation step is skipped when its (non-empty) output already exists
# in the working directory.

# make the script stop when an error (non-zero exit code) occurs
set -e
CPU="32" # number of CPUs to use
MEM="64" # max memory (in GB)
# Inputs:
IN="$1" # input.fasta
#WDIR=`realpath -s $2` # working folder
DATABASES_DIR="${2:-/databases}"
WEIGHTS_DIR="${3:-/weights}"
WDIR="${4:-.}"
# All path expansions below are quoted so that directories containing spaces
# or glob characters work.
LEN=$(tail -n1 "$IN" | wc -m)
mkdir -p "$WDIR/logs"
###########################################################
# 1. generate MSAs
############################################################
if [ ! -s "$WDIR/t000_.msa0.a3m" ]
then
    echo "Running HHblits - looking for the MSAs"
    /workspace/rf/input_prep/make_msa.sh "$IN" "$WDIR" "$CPU" "$MEM" "$DATABASES_DIR" > "$WDIR/logs/make_msa.stdout" 2> "$WDIR/logs/make_msa.stderr"
fi
############################################################
# 2. predict secondary structure for HHsearch run
############################################################
if [ ! -s "$WDIR/t000_.ss2" ]
then
    echo "Running PSIPRED - looking for the Secondary Structures"
    /workspace/rf/input_prep/make_ss.sh "$WDIR/t000_.msa0.a3m" "$WDIR/t000_.ss2" > "$WDIR/logs/make_ss.stdout" 2> "$WDIR/logs/make_ss.stderr"
fi
############################################################
# 3. search for templates
############################################################
if [ ! -s "$WDIR/t000_.hhr" ]
then
    echo "Running hhsearch - looking for the Templates"
    /workspace/rf/input_prep/prepare_templates.sh "$WDIR" "$CPU" "$MEM" "$DATABASES_DIR" > "$WDIR/logs/prepare_templates.stdout" 2> "$WDIR/logs/prepare_templates.stderr"
fi
############################################################
# 4. end-to-end prediction
############################################################
DB="$DATABASES_DIR/pdb100_2021Mar03/pdb100_2021Mar03"
# NOTE(review): the guard tests $WDIR/t000_.3track.npz while the step writes to
# /results/output.e2e -- confirm whether this step is ever meant to be skipped.
if [ ! -s "$WDIR/t000_.3track.npz" ]
then
    echo "Running end-to-end prediction"
    python /workspace/rf/network/predict_e2e.py \
        -m "$WEIGHTS_DIR" \
        -i "$WDIR/t000_.msa0.a3m" \
        -o /results/output.e2e \
        --hhr "$WDIR/t000_.hhr" \
        --atab "$WDIR/t000_.atab" \
        --db "$DB"
fi
echo "Done."
echo "Output saved as /results/output.e2e.pdb"
# 1> $WDIR/log/network.stdout 2> $WDIR/log/network.stderr

View file

@ -0,0 +1,123 @@
#!/bin/bash
# PyRosetta-based structure prediction pipeline.
# Usage: <this script> input.fasta working_folder
# Steps whose (non-empty) output already exists in the working folder are
# skipped, so an interrupted run can be resumed.

# make the script stop when an error (non-zero exit code) occurs
set -e
############################################################
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('conda' 'shell.bash' 'hook' 2> /dev/null)"
eval "$__conda_setup"
unset __conda_setup
# <<< conda initialize <<<
############################################################
# All path expansions below are quoted so that directories containing spaces
# or glob characters work (exception: the generated command list in step 5).
SCRIPT=$(realpath -s "$0")
export PIPEDIR=$(dirname "$SCRIPT")
CPU="8" # number of CPUs to use
MEM="64" # max memory (in GB)
# Inputs:
IN="$1" # input.fasta
WDIR=$(realpath -s "$2") # working folder
LEN=$(tail -n1 "$IN" | wc -m)
mkdir -p "$WDIR/log"
conda activate RoseTTAFold
############################################################
# 1. generate MSAs
############################################################
if [ ! -s "$WDIR/t000_.msa0.a3m" ]
then
    echo "Running HHblits"
    "$PIPEDIR/input_prep/make_msa.sh" "$IN" "$WDIR" "$CPU" "$MEM" > "$WDIR/log/make_msa.stdout" 2> "$WDIR/log/make_msa.stderr"
fi
############################################################
# 2. predict secondary structure for HHsearch run
############################################################
if [ ! -s "$WDIR/t000_.ss2" ]
then
    echo "Running PSIPRED"
    "$PIPEDIR/input_prep/make_ss.sh" "$WDIR/t000_.msa0.a3m" "$WDIR/t000_.ss2" > "$WDIR/log/make_ss.stdout" 2> "$WDIR/log/make_ss.stderr"
fi
############################################################
# 3. search for templates
############################################################
DB="$PIPEDIR/pdb100_2021Mar03/pdb100_2021Mar03"
if [ ! -s "$WDIR/t000_.hhr" ]
then
    echo "Running hhsearch"
    # keep the command in an array so that every argument stays one word
    HH=(hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu "$CPU" -maxmem "$MEM" -aliw 100000 -e 100 -p 5.0 -d "$DB")
    cat "$WDIR/t000_.ss2" "$WDIR/t000_.msa0.a3m" > "$WDIR/t000_.msa0.ss2.a3m"
    "${HH[@]}" -i "$WDIR/t000_.msa0.ss2.a3m" -o "$WDIR/t000_.hhr" -atab "$WDIR/t000_.atab" -v 0 > "$WDIR/log/hhsearch.stdout" 2> "$WDIR/log/hhsearch.stderr"
fi
############################################################
# 4. predict distances and orientations
############################################################
if [ ! -s "$WDIR/t000_.3track.npz" ]
then
    echo "Predicting distance and orientations"
    python "$PIPEDIR/network/predict_pyRosetta.py" \
        -m "$PIPEDIR/weights" \
        -i "$WDIR/t000_.msa0.a3m" \
        -o "$WDIR/t000_.3track" \
        --hhr "$WDIR/t000_.hhr" \
        --atab "$WDIR/t000_.atab" \
        --db "$DB" 1> "$WDIR/log/network.stdout" 2> "$WDIR/log/network.stderr"
fi
############################################################
# 5. perform modeling
############################################################
mkdir -p "$WDIR/pdb-3track"
conda deactivate
conda activate folding
# Emit one folding command per missing model (3 modes x 5 probability cutoffs
# x 1 repeat); the list is executed by GNU parallel below.
# NOTE(review): paths inside the generated command lines are not quoted, so
# this step still requires paths without whitespace.
for m in 0 1 2
do
    for p in 0.05 0.15 0.25 0.35 0.45
    do
        for ((i=0;i<1;i++))
        do
            if [ ! -f "$WDIR/pdb-3track/model${i}_${m}_${p}.pdb" ]; then
                echo "python -u $PIPEDIR/folding/RosettaTR.py --roll -r 3 -pd $p -m $m -sg 7,3 $WDIR/t000_.3track.npz $IN $WDIR/pdb-3track/model${i}_${m}_${p}.pdb"
            fi
        done
    done
done > "$WDIR/parallel.fold.list"
N=$(wc -l < "$WDIR/parallel.fold.list")
if [ "$N" -gt "0" ]; then
    echo "Running parallel RosettaTR.py"
    parallel -j "$CPU" < "$WDIR/parallel.fold.list" > "$WDIR/log/folding.stdout" 2> "$WDIR/log/folding.stderr"
fi
############################################################
# 6. Pick final models
############################################################
# count the per-model .npz files (feature files excluded) already present
count=$(find "$WDIR/pdb-3track" -maxdepth 1 -name '*.npz' | grep -v 'features' | wc -l)
if [ "$count" -lt "15" ]; then
    # run DeepAccNet-msa
    echo "Running DeepAccNet-msa"
    python "$PIPEDIR/DAN-msa/ErrorPredictorMSA.py" --roll -p "$CPU" "$WDIR/t000_.3track.npz" "$WDIR/pdb-3track" "$WDIR/pdb-3track" 1> "$WDIR/log/DAN_msa.stdout" 2> "$WDIR/log/DAN_msa.stderr"
fi
if [ ! -s "$WDIR/model/model_5.crderr.pdb" ]
then
    echo "Picking final models"
    python -u -W ignore "$PIPEDIR/DAN-msa/pick_final_models.div.py" \
        "$WDIR/pdb-3track" "$WDIR/model" "$CPU" > "$WDIR/log/pick.stdout" 2> "$WDIR/log/pick.stderr"
    echo "Final models saved in: $2/model"
fi
echo "Done"

View file

@ -0,0 +1,33 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Download the pre-trained weights and the sequence/template databases.
# Usage: <this script> [weights_dir] [databases_dir]
WEIGHTS_DIR="${1:-.}"
DATABASES_DIR="${2:-./databases}"
# Create both targets up front: tar -C fails on a directory that does not exist.
mkdir -p "$WEIGHTS_DIR" "$DATABASES_DIR"
echo "Downloading pre-trained model weights [1G]"
wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
tar xfz weights.tar.gz -C "$WEIGHTS_DIR"
# uniref30 [46G]
echo "Downloading UniRef30_2020_06 [46G]"
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz -P "$DATABASES_DIR"
mkdir -p "$DATABASES_DIR/UniRef30_2020_06"
tar xfz "$DATABASES_DIR/UniRef30_2020_06_hhsuite.tar.gz" -C "$DATABASES_DIR/UniRef30_2020_06"
# BFD [272G]
#wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
#mkdir -p bfd
#tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
# structure templates (including *_a3m.ffdata, *_a3m.ffindex) [over 100G]
echo "Downloading pdb100_2021Mar03 [over 100G]"
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz -P "$DATABASES_DIR"
tar xfz "$DATABASES_DIR/pdb100_2021Mar03.tar.gz" -C "$DATABASES_DIR/"

View file

@ -0,0 +1,11 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Port for JupyterLab: first positional argument, 6006 when omitted.
PORT="${1:-6006}"
# Start JupyterLab rooted at /workspace/rf on all interfaces (0.0.0.0), as
# root and without opening a browser.
# NOTE(review): the token is set to an empty string and allow_origin to '*' --
# presumably this turns off token authentication and origin checks; confirm
# the port is only reachable from trusted networks.
jupyter lab --ip=0.0.0.0 --allow-root --no-browser --NotebookApp.token='' --notebook-dir=/workspace/rf --NotebookApp.allow_origin='*' --port $PORT