Overview

Invariant Point Attention - Pytorch

Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alphafold2 for coordinate refinement.

  • write up a test for invariance under rotation
  • enforce float32 for certain operations
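Here is a minimal sketch of what the rotation-invariance test in the first to-do item could check. It does not call the library. It re-implements only the point term of IPA, under two assumptions: frames act on column vectors as x_global = R @ x + t, and the term is the sum over points of squared distances between globally placed query and key points. The helper names (`random_rotation`, `to_global`, `point_distance_term`, `check_invariance`) are hypothetical.

```python
import torch

def random_rotation(gen: torch.Generator) -> torch.Tensor:
    """Random 3x3 rotation (orthogonal, det = +1) from the QR decomposition of a Gaussian matrix."""
    q, _ = torch.linalg.qr(torch.randn(3, 3, generator=gen, dtype=torch.float64))
    if torch.linalg.det(q) < 0:
        q = q.clone()
        q[:, 0] = -q[:, 0]  # flip one axis so a reflection becomes a proper rotation
    return q

def to_global(points, rotations, translations):
    """Local points (b, n, p, 3) -> global frame, assuming x_global = R @ x + t per residue."""
    return torch.einsum('bnij,bnpj->bnpi', rotations, points) + translations[:, :, None, :]

def point_distance_term(q_local, k_local, rotations, translations):
    """Sum over points of squared distances between global query and key points -> (b, n, n)."""
    q = to_global(q_local, rotations, translations)  # (b, n, p, 3)
    k = to_global(k_local, rotations, translations)  # (b, n, p, 3)
    diff = q[:, :, None] - k[:, None, :]              # (b, n, n, p, 3)
    return (diff ** 2).sum(dim=-1).sum(dim=-1)        # sum over xyz, then over points

def check_invariance(seed=0, b=1, n=5, p=4, atol=1e-9):
    gen = torch.Generator().manual_seed(seed)
    rotations = torch.stack([random_rotation(gen) for _ in range(b * n)]).view(b, n, 3, 3)
    translations = torch.randn(b, n, 3, generator=gen, dtype=torch.float64)
    q_local = torch.randn(b, n, p, 3, generator=gen, dtype=torch.float64)
    k_local = torch.randn(b, n, p, 3, generator=gen, dtype=torch.float64)

    # apply one global rigid motion (Q, s) to every frame: R -> Q R, t -> Q t + s
    Q = random_rotation(gen)
    s = torch.randn(3, generator=gen, dtype=torch.float64)
    moved_rotations = Q @ rotations            # matmul broadcasts Q over (b, n)
    moved_translations = translations @ Q.T + s

    before = point_distance_term(q_local, k_local, rotations, translations)
    after = point_distance_term(q_local, k_local, moved_rotations, moved_translations)
    return torch.allclose(before, after, atol=atol)

print(check_invariance())  # True
```

A test against the real module would follow the same pattern: move every frame by the same global rotation and translation (written in whatever convention the module uses) and compare the outputs with `torch.allclose`.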

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use this module without pairwise representations, which are very specific to the Alphafold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use a single IPA-based transformer block, which is an IPA followed by a feedforward network. By default it uses post-layernorm, as in the official code, but you can also try pre-layernorm by setting post_norm = False.

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration
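The last comment leaves the frame update to the reader. Below is a minimal sketch of that update. It assumes an AlphaFold2-style backbone update: the 3 predicted numbers are the imaginary part of a quaternion whose real part is fixed to 1, and the quaternion is normalized to unit length. It also assumes frames act on column vectors as x -> R @ x + t. The function names are hypothetical and are not part of `invariant_point_attention`.

```python
import torch
import torch.nn.functional as F

def quaternion_update_to_rotation(bcd: torch.Tensor) -> torch.Tensor:
    """(..., 3) update (b, c, d) -> (..., 3, 3) rotation from the unit quaternion (1, b, c, d) / norm."""
    q = F.pad(bcd, (1, 0), value=1.0)       # prepend the real part a = 1 -> (..., 4)
    q = q / q.norm(dim=-1, keepdim=True)     # normalize to a unit quaternion
    a, b, c, d = q.unbind(dim=-1)
    entries = [
        a * a + b * b - c * c - d * d, 2 * (b * c - a * d), 2 * (b * d + a * c),
        2 * (b * c + a * d), a * a - b * b + c * c - d * d, 2 * (c * d - a * b),
        2 * (b * d - a * c), 2 * (c * d + a * b), a * a - b * b - c * c + d * d,
    ]
    return torch.stack(entries, dim=-1).reshape(*bcd.shape[:-1], 3, 3)

def apply_backbone_update(rotations, translations, quaternion_update, translation_update):
    """Compose T_new = T_old o T_update, assuming frames act on column vectors: x -> R @ x + t."""
    rotation_update = quaternion_update_to_rotation(quaternion_update)  # (b, n, 3, 3)
    new_rotations = rotations @ rotation_update
    new_translations = translations + torch.einsum('bnij,bnj->bni', rotations, translation_update)
    return new_rotations, new_translations

rotations = torch.eye(3).expand(1, 256, 3, 3)
translations = torch.zeros(1, 256, 3)
quaternion_update, translation_update = torch.randn(1, 256, 6).chunk(2, dim=-1)
rotations, translations = apply_backbone_update(rotations, translations, quaternion_update, translation_update)
print(rotations.shape, translations.shape)  # torch.Size([1, 256, 3, 3]) torch.Size([1, 256, 3])
```

The update is composed on the right (T_old ∘ T_update). Because it is predicted from invariant features, this keeps the refined frames equivariant under a global rigid motion.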

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
Comments
  • Computing point dist - use cartesian dimension instead of hidden dimension

    https://github.com/lucidrains/invariant-point-attention/blob/2f1fb7ca003d9c94d4144d1f281f8cbc914c01c2/invariant_point_attention/invariant_point_attention.py#L130

    I think it should be dim=-1, thus using the cartesian (xyz) axis, rather than dim=-2, which uses the hidden dimension.

    opened by aced125 3
  • In-place rotation detach not allowed

    Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

    Traceback (most recent call last):
      File "denoise.py", line 56, in <module>
        denoised_coords = net(
      File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
        rotations.detach_()
    RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
    

    Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure if this allocates a separate tensor, or just creates a new node pointing to the same data.

    opened by sidnarayanan 1
  • Report a bug that causes instability in training

    Hi, I would like to report a bug in the rotation handling that causes instability in training. https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L322

    The IPA Transformer is similar to the structure module in AF2, where recycling is used. Note that we usually detach the gradient of the rotation, because letting it flow may cause instability during training: the gradient would update the rotation during back propagation, which in our experiments results in instability. Therefore we usually detach the rotation to remove the effect of gradient descent on it. I have seen you do this in your alphafold2 repo (https://github.com/lucidrains/alphafold2).

    If you think this is a problem, please let me know. I am happy to submit a pr to fix that.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Subtle mistake in the implementation

    Hi. Thanks for your implementation. It is very helpful. However, I find that you are missing the dropout in the IPA module.

    https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L239

    In the AlphaFold2 supplementary information, dropout is nested inside the layer norm after the IPA update, i.e. LayerNorm(Dropout(s)). The same holds for the layer norm after the transition layer (line 9 in the figure below). [image not preserved]

    If you think this is a problem, please let me know. I will submit a PR to fix it. Thanks again for sharing such an amazing repo.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • change quaternions update as original alphafold2

    In the original AlphaFold2 IPA module, a pure-quaternion description (without a real part) is used for the quaternion update, which can be broken down into a residual-update-like formulation. This code, however, uses (1, a, b, c)-style quaternions, so I believe the quaternion update should be done as a simple multiplicative update. As far as I have tested, the loss seems to go down more efficiently with this modification.

    opened by ShintaroMinami 1
  • #126 maybe omit the 'self.point_attn_logits_scale'?

    Hi luci:

    I read the original paper, compared it to your implementation, and found one place that might be a mistake:

    #126. attn_logits_points = -0.5 * (point_dist * point_weights).sum(dim = -1)

    I think it should be attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale).sum(dim = -1)

    Thanks for your sharing!

    opened by CiaoHe 1
  • Application of Invariant point attention: preserve part of structure.

    Hi lucidrains, first of all, thank you very much for your work!

    I have a question: how can I change (denoise) the structure only in a region I choose? (denoise.py)

    opened by hw-protein 0
  • Equivariance test for IPA Transformer

    @lucidrains I would like to ask about the equivariance of the transformer (not IPA blocks). I wonder if you checked for the equivariance of the output when you allow the transformation of local points to global points using the updated quaternions and translations. I am not sure why this test fails in my case.

    opened by amrhamedp 1
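On the "Computing point dist" issue, here is a small shape sketch. It assumes query and key points are laid out as (batch, heads, residues, points, xyz), so that xyz is the last axis. This layout is an assumption, not read from the repository. In this layout, reducing xyz first and then the points gives the same total as the reverse order. The choice of axis therefore matters only if something per-axis (such as a per-point weight) is applied between the two sums, or if one of the sums is skipped.

```python
import torch

def point_sq_dist(q_points: torch.Tensor, k_points: torch.Tensor) -> torch.Tensor:
    """Inputs (b, h, n, p, 3) with xyz last (assumed layout); returns (b, h, n, n)."""
    diff = q_points[:, :, :, None] - k_points[:, :, None, :]  # (b, h, i, j, p, 3)
    per_point = (diff ** 2).sum(dim=-1)   # reduce the cartesian xyz axis -> (b, h, i, j, p)
    return per_point.sum(dim=-1)          # then reduce over points -> (b, h, i, j)

torch.manual_seed(0)
q_points = torch.randn(1, 8, 16, 4, 3)
k_points = torch.randn(1, 8, 16, 4, 3)
dist = point_sq_dist(q_points, k_points)
print(dist.shape)  # torch.Size([1, 8, 16, 16])
```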
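This sketch bears on the two detach-related issues above ("In-place rotation detach not allowed" and "Report a bug that causes instability in training"). `Tensor.detach()` works out of place but does not copy data. It returns a new tensor that shares storage with the original and has no autograd history. The sketch shows that, and how detaching the incoming rotations at each refinement step stops gradients from flowing back into earlier frames. `refine` and its `rotations @ weight` step are hypothetical stand-ins for a stack of IPA blocks, not the repository's code.

```python
import torch

def refine(init_rotations, weight, steps=3, detach_rotations=True):
    """Hypothetical refinement loop; `rotations @ weight` stands in for composing with a predicted update."""
    rotations = init_rotations
    for _ in range(steps):
        if detach_rotations:
            rotations = rotations.detach()  # out-of-place stop-gradient, the fix suggested in the issue
        rotations = rotations @ weight
    return rotations

# detach() aliases the same memory: no copy, just a new tensor without autograd history
base = torch.randn(4, 3, 3, requires_grad=True)
view = base[1:]                 # a view of a tensor that requires grad
alias = view.detach()
print(alias.data_ptr() == view.data_ptr(), alias.requires_grad)  # True False

# with detaching, no gradient reaches the initial frames, but the "update" weight still gets one
init = torch.eye(3).repeat(2, 1, 1).requires_grad_()
weight = torch.full((3, 3), 0.5, requires_grad=True)
refine(init, weight).sum().backward()
print(init.grad is None, weight.grad is not None)  # True True
```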
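On the "Subtle mistake in the implementation" issue, here is a minimal post-norm residual wrapper following our reading of the supplement's pseudocode, LayerNorm(Dropout(x + f(x))). The default rate of 0.1 is an assumption. The class name `PostNormDropoutResidual` is hypothetical; it is not the repository's `IPABlock`.

```python
import torch
from torch import nn

class PostNormDropoutResidual(nn.Module):
    """Hypothetical wrapper: x -> LayerNorm(Dropout(x + fn(x)))."""
    def __init__(self, dim: int, fn: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.fn = fn
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # residual add first, then dropout on the sum, then layer norm (post-norm ordering)
        return self.norm(self.dropout(x + self.fn(x, **kwargs)))

# the transition layer would be wrapped the same way, e.g. with a small feedforward as `fn`
transition = PostNormDropoutResidual(64, nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)))
out = transition(torch.randn(1, 256, 64))
print(out.shape)  # torch.Size([1, 256, 64])
```

In eval mode the dropout is the identity, so the wrapper reduces to plain post-norm, LayerNorm(x + f(x)).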
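On "change quaternions update as original alphafold2", here is a sketch of the multiplicative update the issue proposes, q_new = q_old ⊗ normalize(1, b, c, d). Quaternions are stored real part first as (a, b, c, d). The function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Hamilton product of quaternions stored as (..., 4) = (a, b, c, d), real part first."""
    a1, b1, c1, d1 = q1.unbind(-1)
    a2, b2, c2, d2 = q2.unbind(-1)
    return torch.stack([
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
    ], dim=-1)

def multiplicative_update(quaternions: torch.Tensor, bcd: torch.Tensor) -> torch.Tensor:
    """q_new = q_old * normalize(1, b, c, d): compose the current orientation with the predicted update."""
    update = F.pad(bcd, (1, 0), value=1.0)            # (..., 3) -> (..., 4) with real part 1
    update = update / update.norm(dim=-1, keepdim=True)
    return quaternion_multiply(quaternions, update)

i = torch.tensor([0., 1., 0., 0.])
j = torch.tensor([0., 0., 1., 0.])
print(quaternion_multiply(i, j))  # tensor([0., 0., 0., 1.])
```

A zero update (b = c = d = 0) gives the identity quaternion, so the orientation is left unchanged. The product of two unit quaternions stays unit length, so no renormalization of `quaternions` is needed beyond floating-point drift.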
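On the #126 question, the attention logits in the AlphaFold2 supplementary information (Algorithm 22), as we transcribe them, are given below for comparison. Here γ^h is a learned per-head weight and w_C is a fixed constant, so the point term carries both a learned and a fixed scale. Whether `self.point_attn_logits_scale` plays the role of w_C, or that scale is already folded into `point_weights`, has to be checked against the source.

```latex
a_{ij}^{h} = \operatorname{softmax}_{j}\!\left( w_L \left( \frac{1}{\sqrt{c}}\, \mathbf{q}_i^{h\top} \mathbf{k}_j^{h} + b_{ij}^{h} - \frac{\gamma^{h} w_C}{2} \sum_{p} \left\lVert T_i \circ \vec{\mathbf{q}}_i^{hp} - T_j \circ \vec{\mathbf{k}}_j^{hp} \right\rVert^2 \right) \right),
\qquad w_L = \sqrt{\tfrac{1}{3}}, \quad w_C = \sqrt{\tfrac{2}{9 N_{\text{query points}}}}
```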
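On "Equivariance test for IPA Transformer", one possible cause of a failing test is the side on which frame updates are composed. The sketch below assumes frames act as x -> R @ x + t. It also assumes the predicted updates (dR, dt) are functions of invariant features, so a global motion (Q, s) leaves them unchanged. Under those assumptions, composing on the right (T ∘ dT) keeps the global points equivariant, while composing on the left (dT ∘ T) does not. The helper names are hypothetical.

```python
import torch

def random_rotation(gen: torch.Generator) -> torch.Tensor:
    """Random 3x3 rotation (det = +1) via QR decomposition."""
    q, _ = torch.linalg.qr(torch.randn(3, 3, generator=gen, dtype=torch.float64))
    if torch.linalg.det(q) < 0:
        q = q.clone()
        q[:, 0] = -q[:, 0]  # turn a reflection into a rotation
    return q

def to_global(local_points, R, t):
    """(n, p, 3) local points -> global, assuming x -> R @ x + t for each residue frame."""
    return torch.einsum('nij,npj->npi', R, local_points) + t[:, None, :]

def check_equivariance(compose_on_right=True, seed=0, n=6, p=3):
    gen = torch.Generator().manual_seed(seed)
    f64 = dict(generator=gen, dtype=torch.float64)
    R = torch.stack([random_rotation(gen) for _ in range(n)])
    t = torch.randn(n, 3, **f64)
    # updates are assumed to come from invariant features, so a global motion leaves them unchanged
    dR = torch.stack([random_rotation(gen) for _ in range(n)])
    dt = torch.randn(n, 3, **f64)
    x_local = torch.randn(n, p, 3, **f64)
    Q, s = random_rotation(gen), torch.randn(3, **f64)

    def update(R, t):
        if compose_on_right:  # T_new = T o dT
            return R @ dR, t + torch.einsum('nij,nj->ni', R, dt)
        return dR @ R, torch.einsum('nij,nj->ni', dR, t) + dt  # T_new = dT o T

    out = to_global(x_local, *update(R, t))
    out_moved = to_global(x_local, *update(Q @ R, t @ Q.T + s))
    return torch.allclose(out_moved, out @ Q.T + s, atol=1e-9)

print(check_equivariance(True), check_equivariance(False))  # True False
```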
Owner
Phil Wang
Working with Attention
Implementation of the GVP-Transformer, which was used in the paper "Learning inverse folding from millions of predicted structures" for de novo protein design alongside Alphafold2

GVP Transformer (wip) Implementation of the GVP-Transformer, which was used in the paper Learning inverse folding from millions of predicted structure

Phil Wang 19 May 6, 2022
Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2

Graph Transformer - Pytorch Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This was recently used by bot

Phil Wang 97 Dec 28, 2022
Code for our CVPR2021 paper coordinate attention

Coordinate Attention for Efficient Mobile Network Design (preprint) This repository is a PyTorch implementation of our coordinate attention (will appe

Qibin (Andrew) Hou 726 Jan 5, 2023
An SE(3)-invariant autoencoder for generating the periodic structure of materials

Crystal Diffusion Variational AutoEncoder This software implements Crystal Diffusion Variational AutoEncoder (CDVAE), which generates the periodic st

Tian Xie 94 Dec 10, 2022
RefineGNN - Iterative refinement graph neural network for antibody sequence-structure co-design (RefineGNN)

Iterative refinement graph neural network for antibody sequence-structure co-des

Wengong Jin 83 Dec 31, 2022
Code for paper "ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation"

ASAP-Net This project implements ASAP-Net of paper ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation (BMVC2020). Overview We i

Hanwen Cao 26 Aug 25, 2022
A modified version of DeepMind's Alphafold2 to divide CPU part (MSA and template searching) and GPU part (prediction model)

ParallelFold Author: Bozitao Zhong This is a modified version of DeepMind's Alphafold2 to divide CPU part (MSA and template searching) and GPU part (p

Bozitao Zhong 77 Dec 22, 2022
PRIN/SPRIN: On Extracting Point-wise Rotation Invariant Features

PRIN/SPRIN: On Extracting Point-wise Rotation Invariant Features Overview This repository is the Pytorch implementation of PRIN/SPRIN: On Extracting P

Yang You 17 Mar 2, 2022
Point Cloud Denoising input segmentation output raw point-cloud valid/clear fog rain de-noised Abstract Lidar sensors are frequently used in environme

null 75 Nov 24, 2022
A PyTorch Implementation of Single Shot Scale-invariant Face Detector.

S³FD: Single Shot Scale-invariant Face Detector A PyTorch Implementation of Single Shot Scale-invariant Face Detector. Eval python wider_eval_pytorch.

carwin 235 Jan 7, 2023
PyTorch 1.5 implementation for paper DECOR-GAN: 3D Shape Detailization by Conditional Refinement.

DECOR-GAN PyTorch 1.5 implementation for paper DECOR-GAN: 3D Shape Detailization by Conditional Refinement, Zhiqin Chen, Vladimir G. Kim, Matthew Fish

Zhiqin Chen 72 Dec 31, 2022
PyTorch Implementation of Google Brain's WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis

WaveGrad2 - PyTorch Implementation PyTorch Implementation of Google Brain's WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis. Status (202

Keon Lee 59 Dec 6, 2022
Pytorch implementation of “Recursive Non-Autoregressive Graph-to-Graph Transformer for Dependency Parsing with Iterative Refinement”

Graph-to-Graph Transformers Self-attention models, such as Transformer, have been hugely successful in a wide range of natural language processing (NL

Idiap Research Institute 40 Aug 14, 2022
Photographic Image Synthesis with Cascaded Refinement Networks - Pytorch Implementation

Photographic Image Synthesis with Cascaded Refinement Networks-Pytorch (https://arxiv.org/abs/1707.09405) This is a Pytorch implementation of cascaded

Soumya Tripathy 63 Mar 27, 2022
Unofficial implementation of Image Super-Resolution via Iterative Refinement in Pytorch

Image Super-Resolution via Iterative Refinement Paper | Project Brief This is an unofficial implementation of Image Super-Resolution via Iterative Re

LiangWei Jiang 2.5k Jan 2, 2023
This repository implements and evaluates convolutional networks on the Möbius strip as toy model instantiations of Coordinate Independent Convolutional Networks.

Orientation independent Möbius CNNs This repository implements and evaluates convolutional networks on the Möbius strip as toy model instantiations of

Maurice Weiler 59 Dec 9, 2022
Progressive Coordinate Transforms for Monocular 3D Object Detection

Progressive Coordinate Transforms for Monocular 3D Object Detection This repository is the official implementation of PCT. Introduction In this paper,

null 58 Nov 6, 2022
Implementation of the "PSTNet: Point Spatio-Temporal Convolution on Point Cloud Sequences" paper.

PSTNet: Point Spatio-Temporal Convolution on Point Cloud Sequences Introduction Point cloud sequences are irregular and unordered in the spatial dimen

Hehe Fan 63 Dec 9, 2022
Implementation of the "Point 4D Transformer Networks for Spatio-Temporal Modeling in Point Cloud Videos" paper.

Point 4D Transformer Networks for Spatio-Temporal Modeling in Point Cloud Videos Introduction Point cloud videos exhibit irregularities and lack of or

Hehe Fan 101 Dec 29, 2022