An implementation of Performer, a linear attention-based transformer, in Pytorch

Performer - Pytorch

An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random features approach (FAVOR+).
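As a rough illustration of the idea behind FAVOR+, the sketch below approximates softmax attention with positive random features and then reorders the matrix products so the cost grows linearly with sequence length. It is a NumPy sketch under stated assumptions, not this library's code: the function names are made up for illustration, and the projection rows are plain Gaussian samples, whereas FAVOR+ additionally makes them orthogonal.

```python
import numpy as np

def positive_random_features(x, projection):
    """Positive softmax-kernel features: exp(w.x - |x|^2 / 2) / sqrt(m)."""
    m = projection.shape[0]
    wx = x @ projection.T                                   # (n, m)
    sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)  # (n, 1)
    return np.exp(wx - sq_norm) / np.sqrt(m)

def linear_attention(q, k, v, projection):
    """Non-causal attention in O(n * m * d) rather than O(n^2 * d)."""
    scale = q.shape[-1] ** -0.25          # splits the usual 1/sqrt(d) between q and k
    q_prime = positive_random_features(q * scale, projection)
    k_prime = positive_random_features(k * scale, projection)
    kv = k_prime.T @ v                            # (m, d): keys and values summarised once
    normaliser = q_prime @ k_prime.sum(axis=0)    # (n,): row sums of the implicit attention matrix
    return (q_prime @ kv) / normaliser[:, None]

rng = np.random.default_rng(0)
n, d, m = 128, 16, 64                      # sequence length, head dimension, random features
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
projection = rng.standard_normal((m, d))   # assumption: Gaussian rows, not orthogonalised

out = linear_attention(q, k, v, projection)   # (128, 16)
```

The n x n attention matrix is never formed; only the (m, d) summary `kv` and the length-n normaliser are computed.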

Install

$ pip install performer-pytorch

Usage

Performer Language Model

import torch
from torch import nn
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = nn.ReLU(),          # the kernel function to be used if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    tie_embedding = False,          # multiply final embeddings with token weights for logits, like gpt decoder
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True      # use rotary positional embedding, which gives linear attention relative positional encoding with no learned parameters; keep it on unless you want the old absolute positional encoding
)

x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)
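For a sense of scale, the default mentioned in the `nb_features` comment can be worked out by hand. The helper below is hypothetical and assumes a natural logarithm, truncation to an integer, and a per-head dimension of `dim // heads`.

```python
import math

def default_nb_features(dim, heads):
    """d * log(d) random features, where d is the per-head dimension (assumed dim // heads)."""
    dim_head = dim // heads
    return int(dim_head * math.log(dim_head))

print(default_nb_features(dim=512, heads=8))  # 266
```

With `dim = 512` and `heads = 8`, the explicit `nb_features = 256` in the example above is therefore close to that default.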

Plain Performer, if you are working with, say, images or other modalities

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)

Encoder / Decoder - Made possible by Thomas Melistas

import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
generate_out_prime = torch.tensor([[0]]) # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN); decode the tokens from here

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing the self-attention layers of a trained full-attention transformer.

import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)

To minimize model surgery, you could instead keep your existing projections and let the FastAttention module perform only the attention step, as follows.

import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo
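The final comment above ("merge heads and combine outputs with Wo") can be sketched as follows. NumPy is used so the snippet runs on its own; in PyTorch the same steps would be a `permute` and `reshape` followed by an `nn.Linear`. The random `w_o` matrix is a hypothetical stand-in for a trained output projection.

```python
import numpy as np

rng = np.random.default_rng(0)
b, h, n, d = 1, 8, 512, 64
out = rng.standard_normal((b, h, n, d))     # stand-in for the FastAttention output
w_o = rng.standard_normal((h * d, h * d))   # hypothetical output projection Wo

# (b, h, n, d) -> (b, n, h, d) -> (b, n, h * d): each position gets its heads side by side
merged = out.transpose(0, 2, 1, 3).reshape(b, n, h * d)
projected = merged @ w_o                    # (1, 512, 512)
```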

Advanced

At the end of training, if you wish to fix the projection matrices so that the model produces deterministic outputs, you can invoke the following:

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers.

Citations

@misc{choromanski2020rethinking,
    title   = {Rethinking Attention with Performers},
    author  = {Krzysztof Choromanski and Valerii Likhosherstov and David Dohan and Xingyou Song and Andreea Gane and Tamas Sarlos and Peter Hawkins and Jared Davis and Afroz Mohiuddin and Lukasz Kaiser and David Belanger and Lucy Colwell and Adrian Weller},
    year    = {2020},
    eprint  = {2009.14794},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{katharopoulos_et_al_2020,
    author  = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title   = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year    = {2020}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy and Mohammad Taghi Saffar and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@techreport{zhuiyiroformer,
    title   = {RoFormer: Transformer with Rotary Position Embeddings - ZhuiyiAI},
    author  = {Jianlin Su},
    year    = {2021},
    url     = "https://github.com/ZhuiyiTechnology/roformer",
}
Comments
  • [Feature] EncoderDecoder framework, similar to ReformerEncDec

    Hello Phil,

    Nice job on this great architecture. I want to use it as an encoder-decoder within DeepSpeed, so I am thinking of writing a wrapper similar to the one you did for Reformer. Do you have any tips on what to pay attention to (no pun intended), and do I need to use padding as in Autopadder?

    Thanks

    opened by gulnazaki 22
  • Causal linear attention benchmark

    First, thanks for this awesome repo!!

    Based on the T5 model classes from Hugging Face's transformers, I tried using Performer attention in place of the original T5 attention. We fine-tuned t5-large for summarization, profiled both time and memory usage, and compared the Performer attention with the original attention. I have only benchmarked with an input size of 1024.

    The results clearly showed that Performer attention uses a lot less memory than the original transformer. I know from the paper that Performer outperforms the original transformer when the input size is bigger than 1024. However, fine-tuning and generation with the Performer actually took longer, so I profiled the forward call of both the original T5 attention and the Performer attention. The forward pass of the T5 Performer took twice as long, and the main bottleneck was causal_dot_product_kernel from fast-transformers.

    Is this normal performance for the Performer or for the causal attention calculation? Or will Performer attention be faster with a bigger input size?

    opened by ice-americano 13
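A back-of-the-envelope operation count gives a feel for where the crossover in the question above might lie. This is only a sketch: it counts multiply-accumulates per head, and it ignores the normaliser, memory traffic, kernel efficiency and the extra sequential work of the causal case, so measured wall-clock time can differ substantially. The values `d = 64` and `m = 256` are borrowed from the README examples, not from the benchmark in the issue.

```python
def full_attention_macs(n, d):
    """Multiply-accumulates per head for softmax attention: QK^T and AV."""
    return 2 * n * n * d

def performer_macs(n, d, m):
    """Per head: feature maps for Q and K, then K'^T V and Q' (K'^T V)."""
    return 4 * n * m * d

for n in (512, 1024, 4096):
    ratio = full_attention_macs(n, 64) / performer_macs(n, 64, 256)
    print(n, ratio)
# 512 1.0
# 1024 2.0
# 4096 8.0
```

Under these assumptions the ratio is n / (2m), so the theoretical advantage at a sequence length of 1024 is small enough to be outweighed by implementation overheads.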
  • Regarding DDP and reversible networks

    Hi, I'm trying to figure out how to combine DDP with setting the network to be reversible.

    My code basically looks like this:

    import pytorch_lightning as pl
    from performer_pytorch import Performer
    ...
    model = nn.Sequential([...,Performer(...,reversible=True)])
    trainer = pl.Trainer(...
                        distributed_backend='ddp',
                        ...)
    trainer.fit(model,train_loader,val_loader)
    

    Now all combinations work for me (ddp/not reversible, not ddp/reversible, not ddp/not reversible) except for ddp and reversible.

    The error I get is:

    RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:

    1. Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
    2. Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.

    I've seen multiple people have similar issues: https://github.com/huggingface/transformers/issues/7160 ,https://github.com/pytorch/pytorch/issues/46166 , https://github.com/tatp22/linformer-pytorch/issues/23

    Do you have any suggestions for how to deal with this issue? I'm not really familiar with the inner workings of DDP and the autograd engine, so I'm not sure how to fix this myself.

    opened by Parskatt 11
  • Triangular matrices ?

    Does the current implementation provide triangular matrices (to constrain the attention always on the "left" of the sequence, both for input and encoded values) as described in the last section of the original paper?

    opened by jeremycochoy 10
  • wrong implementation for autoregressive self-attention

    Hi, I found that you use the fast-transformers CUDA kernel, but it does not contain the normalization part, which needs a cumsum outside the CausalDotProduct (in causal_linear_attention). If I didn't miss something, the result of your code should be wrong... But I am not 100% sure.

    opened by Sleepychord 10
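To make the role of the cumulative sums concrete, here is a NumPy sketch of causal linear attention in which both the numerator and the normaliser are running sums over positions. It is checked against an explicit lower-triangular mask. This illustrates the general technique from the linear-attention literature and is not the code of this library or of fast-transformers; the feature maps are random positive numbers standing in for FAVOR+ features.

```python
import numpy as np

def causal_linear_attention(q_prime, k_prime, v):
    """Causal attention from feature maps, using running sums over positions."""
    # kv[i] = sum over j <= i of outer(k'_j, v_j), shape (n, m, d)
    kv = np.cumsum(k_prime[:, :, None] * v[:, None, :], axis=0)
    k_sum = np.cumsum(k_prime, axis=0)                       # (n, m), the normaliser's cumsum
    numerator = np.einsum('nm,nmd->nd', q_prime, kv)
    denominator = np.einsum('nm,nm->n', q_prime, k_sum)
    return numerator / denominator[:, None]

rng = np.random.default_rng(0)
n, m, d = 32, 16, 8
q_prime = rng.random((n, m)) + 1e-3   # positive features (assumed to come from a feature map)
k_prime = rng.random((n, m)) + 1e-3
v = rng.standard_normal((n, d))

out = causal_linear_attention(q_prime, k_prime, v)

# Reference: explicit lower-triangular mask on the n x n score matrix.
scores = np.tril(q_prime @ k_prime.T)
reference = (scores / scores.sum(axis=1, keepdims=True)) @ v
print(np.allclose(out, reference))    # True
```

Storing every prefix sum, as done here for clarity, costs O(n * m * d) memory; practical kernels keep only the running state.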
  • There are no tests in this project, use_rezero=True is non-functional

    Tests are needed to validate that models can train in various configurations. I built and ran simple tests (I am trying to get authorization to contribute them as a PR) and found that use_rezero=True kills the gradient and results in a Performer model that cannot learn. The fix consists of initializing the rezero parameter with a small but non-zero value (e.g., 1e-3 works in my tests). Zero prevents any signal from passing through the module, so the parameter never changes from zero.

    opened by fcampagne 10
  • Show what the performance on enwiki8 is across your projects

    Hello @lucidrains, I'm a very big fan of your work. It is of such high quality that every time you release a new project I lose sleep wanting to try it.

    You have many different versions of transformers, such as reformer, memory-xl, performer... And apparently you have already tested them with enwiki8.

    Would it be possible to post a table in the README with the enwiki8 runtime, memory and some performance metric? That would be awesome for comparing the different implementations.

    Thanks again for your work!!

    opened by bratao 10
  • Issue with biased estimates from QR decomposition

    Hi again :)

    See issue https://github.com/google-research/google-research/issues/436 , which I posted on the main repository. Using the QR decomposition incorrectly produces results with significantly higher variance. There is quite an easy fix:

     q, r = torch.qr(flattened)
     # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
     d = torch.diag(r, 0)
     ph = d.sign()
     q *= ph
    
    opened by Parskatt 9
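The sign correction proposed in this issue can be tried in isolation. The NumPy sketch below mirrors the snippet above: it multiplies each column of Q by the sign of the matching diagonal entry of R, which leaves Q orthogonal and corresponds to a factorisation whose R has a positive diagonal. The function name is hypothetical, and the sketch assumes no diagonal entry of R is exactly zero, which holds almost surely for a Gaussian input.

```python
import numpy as np

def sign_corrected_orthogonal(n, rng):
    """Orthogonal matrix from the QR of a Gaussian matrix, with the sign correction applied."""
    a = rng.standard_normal((n, n))
    q, r = np.linalg.qr(a)
    signs = np.sign(np.diag(r))
    return q * signs          # broadcasting scales column j of q by signs[j]

rng = np.random.default_rng(0)
q = sign_corrected_orthogonal(64, rng)
print(np.allclose(q.T @ q, np.eye(64)))   # True
```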
  • Collaborate on Implementation?

    I was planning on implementing this in PyTorch as well and started a repo at https://github.com/calclavia/Performer-Pytorch (the kernel is implemented so far). If the author(s) of this repo want to collaborate, I would be happy to contribute.

    opened by calclavia 9
  • Extra FF when using cross attention

    Hello Phil,

    I have noticed that when using cross attention, a new block (with attention and a FeedForward layer) is added, while only an attention layer should be added between the self attention and the FF layer.

    Is there any reason for this?

    opened by gulnazaki 8
  • Add feature_redraw_interval option

    This fork allows the user to select a number of forward passes after which the random features will be redrawn. This allows us to avoid doing QR decomposition on the CPU every forward pass. By default it is set to redraw every 1000 passes.

    opened by norabelrose 8
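The redraw logic described in this pull request can be pictured as a simple call counter. The class below is a hypothetical illustration in NumPy, not the library's implementation: the name, attributes and the plain Gaussian draw are all assumptions.

```python
import numpy as np

class RedrawingProjection:
    """Hypothetical helper: holds a projection matrix and redraws it every `interval` calls."""

    def __init__(self, nb_features, dim_head, interval=1000, seed=0):
        self.shape = (nb_features, dim_head)
        self.interval = interval              # None means never redraw
        self.rng = np.random.default_rng(seed)
        self.calls_since_redraw = 0
        self.redraws = 0
        self.matrix = self._draw()

    def _draw(self):
        self.redraws += 1
        return self.rng.standard_normal(self.shape)

    def get(self):
        """Return the current matrix, redrawing first if the interval has elapsed."""
        if self.interval is not None and self.calls_since_redraw >= self.interval:
            self.matrix = self._draw()
            self.calls_since_redraw = 0
        self.calls_since_redraw += 1
        return self.matrix

proj = RedrawingProjection(nb_features=256, dim_head=64, interval=1000)
for _ in range(2500):
    proj.get()
print(proj.redraws)  # 3: the initial draw plus redraws on calls 1001 and 2001
```

Passing `interval=None` in this sketch freezes the matrix, which corresponds to the fixed-projection behaviour the README describes for the end of training.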
  • Performer Pytorch Slower than Expected and Please Help with Understanding Parameter Count

    Hi,

    First of all, this is a great package from lucidrains and I find it very helpful in my research.

    A quick question is that I noticed ViT-performer is slower than the regular ViT from lucidrains. For example running on mnist from pytorch will take 15 sec/epoch for regular ViT with the configuration below while ViT performer takes 23 sec/epoch.

    Checking the parameter count also shows ViT-performer has double the size of regular ViT.

    I am hoping that someone has intuition about the speed of ViT performer vs regular ViT and their parameter counts.

    Thank you very much in advance!

    opened by weihaosong 1
  • Replicating nn.MultiHeadAttention with multiple Performer SelfAttention modules

    As the title says, has anyone tried replacing multi-head attention in a typical transformer with the self-attention described in this library?

    My thought was that I can essentially concatenate multiple self-attention modules together to replicate this, per the attached image from the torch website.

    I'm relatively new to transformers as a whole, so hopefully this question makes some sense.

    For reference, considering the interest in a previous post, I've been exploring Performer effectiveness with DETR (https://github.com/facebookresearch/detr).

    thanks!

    opened by JGittles 0
  • Question about masking

    Hi, thanks for the wonderful repo. I am new to BERT, so I'd like to check something in your example:

    model = PerformerLM()
    x = torch.randint(0, 20000, (1, 2048))
    mask = torch.ones_like(x).bool()
    model(x, mask = mask) # (1, 2048, 20000)

    Is this 'mask' the attention_mask, i.e. True (1) for normal tokens and False (0) for padding tokens? Or should 1 indicate a padding token? Thanks a lot!

    opened by Microbiods 1
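For padded batches, a mask of the first kind asked about here can be built from the sequence lengths. The sketch below assumes the convention that True marks real tokens and False marks padding. The all-ones mask in the README example (where nothing is padded) is consistent with that reading, but the convention should be confirmed against the library source. The helper name is hypothetical.

```python
import numpy as np

def padding_mask(lengths, max_len):
    """Boolean mask of shape (batch, max_len): True for real tokens, False for padding."""
    positions = np.arange(max_len)[None, :]
    return positions < np.asarray(lengths)[:, None]

mask = padding_mask([5, 3], max_len=6)
print(mask.astype(int))
# [[1 1 1 1 1 0]
#  [1 1 1 0 0 0]]
```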
  • Question: Is Performer order equivariant? (can it transform an unordered set of tensors)

    Hi,

    Thanks for the amazing implementation. I'm wondering if Performer can be used like a set-operator (i.e. whether it is order equivariant)

    For instance, say I have a point cloud and I want to apply self-attention across all the point features. Can Performer be used here (note equivariance: points can be arbitrarily shuffled, but we expect the corresponding transformed features to be identical regardless of the shuffling)?

    Thanks!

    opened by nmakes 0
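The property asked about can be checked numerically for the attention operation itself. The sketch below uses non-causal linear attention with a simple ReLU-based feature map as a stand-in for FAVOR+, and assumes no positional embedding is applied. Under those assumptions, shuffling the inputs and then attending gives the same result as attending and then shuffling. It says nothing about configurations of the library that add positional information or causal masking.

```python
import numpy as np

def feature_map(x):
    """Simple positive feature map (ReLU plus a small constant); a stand-in for FAVOR+."""
    return np.maximum(x, 0.0) + 1e-3

def linear_attention(x, w_q, w_k, w_v):
    q, k, v = feature_map(x @ w_q), feature_map(x @ w_k), x @ w_v
    kv = k.T @ v                       # a sum over positions, so input order does not matter
    normaliser = q @ k.sum(axis=0)
    return (q @ kv) / normaliser[:, None]

rng = np.random.default_rng(0)
n, d = 100, 16
x = rng.standard_normal((n, d))        # e.g. per-point features of a point cloud
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))

perm = rng.permutation(n)
shuffled_then_attended = linear_attention(x[perm], w_q, w_k, w_v)
attended_then_shuffled = linear_attention(x, w_q, w_k, w_v)[perm]
print(np.allclose(shuffled_then_attended, attended_then_shuffled))  # True
```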
  • Using Performer with GNNs

    My understanding of "Rethinking Attention with Performers" is that FAVOR+ is used to approximate the attention matrix and avoids the use of the softmax function. In the README.md file, you note that the plain Performer can be used if we are working with images or other modalities, just as the authors allude to Performer's use in other areas.

    I am interested in using Performer to approximate attention between nodes in a graph neural network. The graph neural network contains vectors characterizing each node's features and boolean edge indices indicating a connection between two nodes.

    Do you have any recommendations on how to do this with the current Performer model? I see that Attention.forward() accepts a mask as input.

    opened by jah377 0
Releases(1.1.4)
Owner
Phil Wang
Working with Attention. It's all we need.