Attention for PyTorch with Linear Memory Footprint
Unofficial PyTorch implementation of https://arxiv.org/abs/2112.05682, which achieves a linear memory cost for attention (plus some speedup on the GPU compared to the reference implementation in JAX).
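The core idea from the paper: instead of materializing the full (length x length) attention matrix, keys and values are processed in chunks while a running max and normalizer are carried along, so memory grows linearly with sequence length. Below is a minimal, self-contained sketch of that idea in PyTorch; it is illustrative only, not this library's actual internals, and it omits masking and query chunking for brevity.

import torch

def chunked_attention(q, k, v, key_chunk_size=4096):
    # q, k, v: (batch, length, heads, features); no mask or query chunking for brevity
    q = q * q.shape[-1] ** -0.5
    num = torch.zeros_like(q)                            # running weighted sum of values
    den = q.new_zeros((*q.shape[:-1], 1))                # running softmax normalizer
    m = q.new_full((*q.shape[:-1], 1), float('-inf'))    # running max of scores, for stability
    for start in range(0, k.shape[1], key_chunk_size):
        k_c = k[:, start:start + key_chunk_size]
        v_c = v[:, start:start + key_chunk_size]
        s = torch.einsum('bqhd,bkhd->bqhk', q, k_c)      # scores for this key chunk only
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        p = torch.exp(s - m_new)
        rescale = torch.exp(m - m_new)                   # rescale previously accumulated sums
        num = num * rescale + torch.einsum('bqhk,bkhd->bqhd', p, v_c)
        den = den * rescale + p.sum(dim=-1, keepdim=True)
        m = m_new
    return num / den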
Installation:
git clone https://github.com/CHARM-Tx/linear_mem_attention_pytorch
cd linear_mem_attention_pytorch
python setup.py install
Usage:
High Level
import torch
from linear_mem_attention_torch.fast_attn import Attention
batch, length, features = 2, 2**8, 64
x, ctx = torch.randn(2, batch, length, features)
mask = torch.randn(batch, length) < 1.
attn = Attention(dim=features, heads=8, dim_head=64, bias=False)
# self-attn
v_self = attn(x, x, mask, query_chunk_size=1024, key_chunk_size=4096)
# cross-attn
v_cross = attn(x, ctx, mask, query_chunk_size=1024, key_chunk_size=4096)
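The mask above is random purely for the sake of the example; in practice it would typically mark which key positions are valid (e.g. non-padding). A sketch of building such a mask from per-sequence lengths, assuming True means "attend to this position" (the usual convention, not verified against this API):

lengths = torch.tensor([200, 256])                        # valid tokens in each sequence
mask = torch.arange(length)[None, :] < lengths[:, None]   # (batch, length) boolean mask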
Low Level
import torch
from linear_mem_attention_torch import attention
batch, length, heads, features = 2, 2**8, 8, 64
mask = torch.randn(batch, length) < 1.
q, k, v = torch.randn(3, batch, length, heads, features)
v_ = attention(q, k, v, mask, query_chunk_size=1024, key_chunk_size=4096)
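As a sanity check, the output can be compared against a naive reference. The snippet below assumes the low-level attention applies the usual 1/sqrt(features) scaling, masks out key positions where mask is False, and returns a (batch, length, heads, features) tensor; these are reasonable readings of the paper but are not verified against this exact API.

scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / features ** 0.5
scores = scores.masked_fill(~mask[:, None, None, :], float('-inf'))
ref = torch.einsum('bhqk,bkhd->bqhd', scores.softmax(dim=-1), v)
print(torch.allclose(v_, ref, atol=1e-5))   # expected to print True under the assumptions above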
Benchmarks
See examples/example_benchmark.ipynb for more information.
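For a rough idea of how to measure peak GPU memory yourself, a short sketch is shown below (illustrative only; it requires a CUDA device, and the numbers depend heavily on sequence length and chunk sizes):

import torch
from linear_mem_attention_torch import attention

q, k, v = torch.randn(3, 1, 2**14, 8, 64, device='cuda')
mask = torch.ones(1, 2**14, dtype=torch.bool, device='cuda')
torch.cuda.reset_peak_memory_stats()
out = attention(q, k, v, mask, query_chunk_size=1024, key_chunk_size=4096)
print(f'peak memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB')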
Citations:
@misc{rabe2021selfattention,
    title         = {Self-attention Does Not Need $O(n^2)$ Memory},
    author        = {Markus N. Rabe and Charles Staats},
    year          = {2021},
    eprint        = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}