
Equiformer - Pytorch (wip)

Implementation of Equiformer, an SE3/E3 equivariant attention network that reaches new SOTA and has been adopted by EquiFold (Prescient Design) for protein folding

The design seems to build off of SE3 Transformers, with the dot product attention replaced by MLP attention and non-linear message passing from GATv2. It also uses a depthwise tensor product (DTP) for a bit more efficiency. If you think I am mistaken, please feel free to email me.
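To make the difference between the two attention styles concrete, here is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). It scores one query against a handful of neighbor keys twice: once with scaled dot products, and once with a GATv2-style MLP score of the form a^T LeakyReLU(W [q || k]). It assumes the scoring acts on invariant (type 0) features only, and all names (`mlp_logits`, `weight`, `attn_vec`) are illustrative and are not taken from this repository.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def dot_product_logits(query, keys):
    # standard scaled dot product between the query and each key
    scale = 1.0 / math.sqrt(len(query))
    return [scale * sum(q * k for q, k in zip(query, key)) for key in keys]

def mlp_logits(query, keys, weight, attn_vec):
    # GATv2-style score: a^T LeakyReLU(W [q || k])
    logits = []
    for key in keys:
        concat = query + key  # list concatenation, length 2 * dim
        hidden = [leaky_relu(sum(w * x for w, x in zip(row, concat))) for row in weight]
        logits.append(sum(a * h for a, h in zip(attn_vec, hidden)))
    return logits

random.seed(0)
dim, hidden_dim, num_neighbors = 4, 8, 5

# random stand-ins for learned parameters and invariant features
query = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(num_neighbors)]
weight = [[random.gauss(0, 1) for _ in range(2 * dim)] for _ in range(hidden_dim)]
attn_vec = [random.gauss(0, 1) for _ in range(hidden_dim)]

dot_weights = softmax(dot_product_logits(query, keys))
mlp_weights = softmax(mlp_logits(query, keys, weight, attn_vec))
```

Both variants produce a normalized distribution over the neighbors; only the way the logits are computed differs.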

Will report how it does on my toy task compared to some other networks I am more familiar with, once construction is complete.

This repository may eventually contain an implementation of EquiFold as well, with the added FAPE loss + structural violation checks

Update: The choice of norm or gating (I still need to ablate to figure out which, or both) is contributing to the greatly improved results. However, MLP attention does not seem to be doing anything (caveat: it is still missing a DTP in the values branch, and of course I need to check for bugs). More experiments pending.

Update: Never mind, MLP attention seems to be working, but it performs about the same as dot product attention.

Update: By using the negative of the Euclidean distance in place of the dot product for higher types in dot product attention (as done in the coordinate branch of IPA in AlphaFold2), I now see results that are far better than before, and better than MLP attention as well. My conclusion is that the choice of norm and gating contributes far more to the results in the paper than MLP attention does.
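The idea of using negative Euclidean distance as an attention logit can be sketched as follows. This is a simplified illustration in plain Python, not the code in this repository: it assumes a type 1 feature is a list of 3-vectors (one per channel), uses the negative squared distance summed over channels as the logit, and omits any learned scaling. The helper names are hypothetical. The final lines rotate every vector by the same rotation to show that the resulting attention weights do not change.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def neg_sq_dist_logit(query, key):
    # query, key: lists of 3-vectors (one per channel) of a type 1 feature
    return -sum(
        (qc - kc) ** 2
        for q_vec, k_vec in zip(query, key)
        for qc, kc in zip(q_vec, k_vec)
    )

def rotate_z(vec, angle):
    # rotate a 3-vector about the z axis
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = vec
    return [c * x - s * y, s * x + c * y, z]

query = [[1.0, 0.0, 0.0], [0.0, 2.0, 1.0]]  # 2 channels of 3-vectors
keys = [
    [[1.0, 0.5, 0.0], [0.0, 1.5, 1.0]],
    [[-1.0, 0.0, 2.0], [3.0, 0.0, 0.0]],
    [[0.9, 0.1, 0.0], [0.1, 2.0, 0.8]],
]

weights = softmax([neg_sq_dist_logit(query, k) for k in keys])

# apply the same rotation to the query and all keys
angle = 0.7
rot_query = [rotate_z(v, angle) for v in query]
rot_keys = [[rotate_z(v, angle) for v in k] for k in keys]
rot_weights = softmax([neg_sq_dist_logit(rot_query, k) for k in rot_keys])
```

Because a rotation preserves distances between vectors, `weights` and `rot_weights` agree up to floating point error, and the key closest to the query receives the largest weight.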

Running experiment: denoising residue positions in a protein sequence

If you see a signal with this architecture, please consider sharing your findings at OpenBioML

Install

$ pip install equiformer-pytorch

Usage

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    num_tokens = 24,
    dim = (4, 4, 2),               # dimensions per type, ascending, length must match number of degrees (num_degrees)
    dim_head = (4, 4, 4),          # dimension per attention head
    heads = (2, 2, 2),             # number of attention heads
    num_linear_attn_heads = 0,     # number of global linear attention heads, can see all the neighbors
    num_degrees = 3,               # number of degrees
    depth = 4,                     # depth of equivariant transformer
    attend_self = True,            # attending to self or not
    reduce_dim_out = True,         # whether to reduce out to dimension of 1, say for predicting new coordinates for type 1 features
    l2_dist_attention = False      # set to False to try out MLP attention
).cuda()

feats = torch.randint(0, 24, (1, 128)).cuda()
coors = torch.randn(1, 128, 3).cuda()
mask  = torch.ones(1, 128).bool().cuda()

out = model(feats, coors, mask) # output exposes .type0 and .type1

out.type0 # invariant type 0    - (1, 128)
out.type1 # equivariant type 1  - (1, 128, 3)

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

Testing

Tests for spherical harmonics, network equivariance etc

$ python setup.py test
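For readers unfamiliar with what an equivariance test checks, the following is a minimal, library-independent sketch in plain Python. It is not the test suite of this repository. It verifies that rotating the input and then applying a function gives the same result as applying the function and then rotating the output. `toy_type1` is a hypothetical stand-in for a model's type 1 output (each point's displacement from the centroid), and `not_equivariant` is a deliberately broken counterexample.

```python
import math
import random

def rot_z(angle):
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_x(angle):
    c, s = math.cos(angle), math.sin(angle)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def apply(rotation, vec):
    return [sum(rotation[i][j] * vec[j] for j in range(3)) for i in range(3)]

def toy_type1(coors):
    # hypothetical stand-in for a type 1 output: displacement from the centroid
    n = len(coors)
    centroid = [sum(p[i] for p in coors) / n for i in range(3)]
    return [[p[i] - centroid[i] for i in range(3)] for p in coors]

def not_equivariant(coors):
    # adds a fixed direction, which does not rotate with the input
    return [[p[0] + 1.0, p[1], p[2]] for p in coors]

def is_equivariant(fn, coors, rotation, tol=1e-9):
    # compare R * f(x) against f(R * x), element by element
    rotated_output = [apply(rotation, v) for v in fn(coors)]
    output_of_rotated = fn([apply(rotation, p) for p in coors])
    return all(
        abs(a - b) <= tol
        for u, w in zip(rotated_output, output_of_rotated)
        for a, b in zip(u, w)
    )

rng = random.Random(0)
coors = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(8)]
rotation = matmul(rot_z(0.9), rot_x(-0.4))

toy_ok = is_equivariant(toy_type1, coors, rotation)
broken_ok = is_equivariant(not_equivariant, coors, rotation)
```

The same pattern applies to a real network: the invariant type 0 output should be unchanged under rotation of the coordinates, while the type 1 output should rotate with them. A looser tolerance is usually needed for a network running in single precision.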

Example

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py
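The contents of denoise.py are not reproduced here. As a rough illustration of what a coordinate denoising objective generally looks like, the sketch below assumes Gaussian noise added to clean positions and a mean squared error loss on the corrected positions. The noise level, helper names, and loss are assumptions for illustration, not values taken from the script.

```python
import random

def add_noise(coors, std, rng):
    # perturb every coordinate with independent Gaussian noise
    return [[c + rng.gauss(0, std) for c in point] for point in coors]

def mse(pred, target):
    count = sum(len(p) for p in pred)
    return sum(
        (a - b) ** 2
        for p, t in zip(pred, target)
        for a, b in zip(p, t)
    ) / count

rng = random.Random(0)
clean = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(16)]
noised = add_noise(clean, 0.5, rng)

# A denoiser predicts a per-point correction (for example, a type 1 output)
# that is added to the noised coordinates.

# baseline: a model that predicts a zero correction everywhere
baseline_loss = mse(noised, clean)

# ideal: a model that predicts exactly the negative of the noise
ideal_correction = [[c - n for c, n in zip(cp, npt)] for cp, npt in zip(clean, noised)]
denoised = [[n + d for n, d in zip(npt, dp)] for npt, dp in zip(noised, ideal_correction)]
ideal_loss = mse(denoised, clean)
```

A trained model should land somewhere between `baseline_loss` and `ideal_loss`; an equivariant model is a natural fit because the correction should rotate together with the input structure.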

Todo

  • move the logic that projects xi and xj separately and sums them into the Conv class

  • move self interacting key / value production into Conv, fix no pooling in conv with self interaction

  • go with a naive way to split up contribution from input degrees for DTP

  • for dot product attention in higher types, try euclidean distance

  • consider an all-neighbors attention layer just for type0, using linear attention

  • start moving some spherical harmonic stuff to cpp or nim

  • add memory checkpointing for the entire DTP module

  • rotate in the relative distances into type0 dot product, as was done here

Citations

@article{Liao2022EquiformerEG,
    title   = {Equiformer: Equivariant Graph Attention Transformer for 3D Atomistic Graphs},
    author  = {Yi Liao and Tess E. Smidt},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.11990}
}
@article{Lee2022.10.07.511322,
    author  = {Lee, Jae Hyeon and Yadollahpour, Payman and Watkins, Andrew and Frey, Nathan C. and Leaver-Fay, Andrew and Ra, Stephen and Cho, Kyunghyun and Gligorijevic, Vladimir and Regev, Aviv and Bonneau, Richard},
    title   = {EquiFold: Protein Structure Prediction with a Novel Coarse-Grained Structure Representation},
    elocation-id = {2022.10.07.511322},
    year    = {2022},
    doi     = {10.1101/2022.10.07.511322},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2022/10/08/2022.10.07.511322},
    eprint  = {https://www.biorxiv.org/content/early/2022/10/08/2022.10.07.511322.full.pdf},
    journal = {bioRxiv}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020}
}
