Efficient Attention

An implementation of the efficient attention module.

Description

Efficient attention is an attention mechanism that substantially improves memory and computational efficiency while retaining exactly the same expressive power as conventional dot-product attention. The illustration above compares the two types of attention. The efficient attention module is a drop-in replacement for the non-local module (Wang et al., 2018), and it:

  • uses fewer resources to achieve the same accuracy;
  • achieves higher accuracy with the same resource constraints (by allowing more insertions); and
  • is applicable in domains and models where the non-local module is not (due to resource constraints).

Resources

YouTube:

bilibili (for users in Mainland China):

Implementation details

This repository implements the efficient attention module with softmax normalization, output reprojection, and residual connection.
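
For orientation, here is a minimal usage sketch. It assumes the EfficientAttention module quoted in the comments below; the import path and the channel counts are illustrative assumptions, not prescriptions from the paper.

    # Minimal usage sketch (illustrative values; hypothetical import path).
    import torch
    from efficient_attention import EfficientAttention

    # A simple configuration: key and value channels equal to the input
    # channels, and a head count that divides both.
    module = EfficientAttention(
        in_channels=64, key_channels=64, head_count=8, value_channels=64)

    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    y = module(x)                   # reprojection + residual keep the shape
    assert y.shape == x.shape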

Features not in the paper

This repository additionally implements the multi-head mechanism, which was not in the paper. To learn more about the mechanism, refer to Vaswani et al. (2017).
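
Each head operates on an equal slice of the key and value channels, so head_count should evenly divide key_channels and value_channels. A small sketch of the split (the values are assumptions, not recommendations):

    # Illustrative head split (assumed values).
    key_channels, value_channels, head_count = 64, 64, 8
    head_key_channels = key_channels // head_count      # 8 key channels per head
    head_value_channels = value_channels // head_count  # 8 value channels per head
    assert head_key_channels * head_count == key_channels
    assert head_value_channels * head_count == value_channels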

Citation

The paper appeared at WACV 2021. If you use, compare with, or refer to this work, please cite:

@inproceedings{shen2021efficient,
    author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
    title = {Efficient Attention: Attention with Linear Complexities},
    booktitle = {WACV},
    year = {2021},
}
Comments
  • How to use efficient attention class

    I'm trying to understand how to use your attention module based on the figure above and the code below.

    From what I understand from the non-local paper, if I have an input feature of m_batchsize, channels, height, width = input_.size(),

    then n = m_batchsize*height*width and d = channels.

    So in the code below, I should set in_channels, key_channels, and value_channels equal to channels.

    But what should head_count be? Should it be divisible by the number of channels?

    import torch
    from torch import nn
    from torch.nn import functional as f
    
    class EfficientAttention(nn.Module):
        
        def __init__(self, in_channels, key_channels, head_count, value_channels):
            super().__init__()
            self.in_channels = in_channels
            self.key_channels = key_channels
            self.head_count = head_count
            self.value_channels = value_channels
    
            self.keys = nn.Conv2d(in_channels, key_channels, 1)
            self.queries = nn.Conv2d(in_channels, key_channels, 1)
            self.values = nn.Conv2d(in_channels, value_channels, 1)
            self.reprojection = nn.Conv2d(value_channels, in_channels, 1)
    
        def forward(self, input_):
            n, _, h, w = input_.size()
            keys = self.keys(input_).reshape((n, self.key_channels, h * w))
            queries = self.queries(input_).reshape(n, self.key_channels, h * w)
            values = self.values(input_).reshape((n, self.value_channels, h * w))
            head_key_channels = self.key_channels // self.head_count
            head_value_channels = self.value_channels // self.head_count
            
            attended_values = []
            for i in range(self.head_count):
                # Softmax over the spatial positions (h * w) for the keys...
                key = f.softmax(keys[
                    :,
                    i * head_key_channels: (i + 1) * head_key_channels,
                    :
                ], dim=2)
                # ...and over the channel dimension for the queries.
                query = f.softmax(queries[
                    :,
                    i * head_key_channels: (i + 1) * head_key_channels,
                    :
                ], dim=1)
                value = values[
                    :,
                    i * head_value_channels: (i + 1) * head_value_channels,
                    :
                ]
                # Global context: (d_k, h*w) @ (h*w, d_v) -> (d_k, d_v),
                # so the cost is linear in the number of positions.
                context = key @ value.transpose(1, 2)
                attended_value = (
                    context.transpose(1, 2) @ query
                ).reshape(n, head_value_channels, h, w)
                attended_values.append(attended_value)
    
            aggregated_values = torch.cat(attended_values, dim=1)
            reprojected_value = self.reprojection(aggregated_values)
            attention = reprojected_value + input_
    
            return attention
    
    opened by chandlerbing65nm 7
  • About module parameters

    Hi, great work. I'm new to DL.

    EfficientAttention has four parameters: in_channels, key_channels, head_count, and value_channels. I know key_channels = d_k and value_channels = d_v. What does head_count mean? What is a common setting for the four parameters?

    opened by TRillionZxY 4
  • About Normalization

    After running the code you provided, I compared it with the traditional self-attention mechanism. The results are quite different: the traditional attention maps are truly normalized, but the efficient ones are not. Can you offer some help?

    opened by VoyageWang 3
  • in_channels, key_channels, head_count, value_channels

    I am trying to use efficient attention but am confused about the parameters. When I change key_channels and value_channels, the output dimension stays the same. Why?

    opened by feimadada 2
  • How to apply scaled dot-product attention by efficient attention?

    Hi,

    Would you please tell me how to use efficient attention so that it is similar to scaled dot-product attention?

    I notice that you apply softmax to both the query and the key. To make this similar to scaled dot-product attention, is it right to set temperature = (d_k) ** 0.25 and divide both slices by it before their softmaxes, i.e. key = f.softmax(keys[:, i * head_key_channels: (i + 1) * head_key_channels, :] / temperature, dim=2) and query = f.softmax(queries[:, i * head_key_channels: (i + 1) * head_key_channels, :] / temperature, dim=1)? (A sketch of this change appears after the comments list.)

    Thank you!

    opened by wangyue7777 2
  • Shape of attention map

    Hi, I found the attention map computed using the script is C x C. Shouldn't it be (H x W) x (H x W) if we want spatial attention?

    Thank you for any information that you can provide.

    opened by aarontyliu 1
  • Can you apply masks in this attention model?

    In seq2seq models, it is common to apply a mask to remove padding, or to remove future inputs in a causal model. Is it possible to do so in efficient attention, since it does not have an explicit key-to-query mapping?

    opened by rongcuid 1
  • Different query positions on the same image

    Hello, thanks for the excellent work. I would like to show different attention maps produced using different query positions (reference bounding boxes). For example, when there are three bounding boxes on the image, I would like to calculate the attention w.r.t. each of them and show the difference. Is it possible to do this?

    opened by horanyinora 1
  • Question about the paper - what is PSMNet (baseline)?

    Hi,

    I'm slightly confused about the difference between these two entries in figure 7 of the paper: PSMNet (original) 1.09 / 0 versus PSMNet (baseline) 0.51 / 0.

    Thanks, Oliver

    opened by oliver-batchelor 1
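
For the scaled dot-product question above, the sketch below spells out the asker's proposed change as a small standalone helper. It is an illustration of that suggestion, not something implemented in this repository; scaled_double_softmax and the per-head temperature are hypothetical names.

    from torch.nn import functional as f

    def scaled_double_softmax(keys, queries, head_key_channels, i):
        # Hypothetical helper: temperature-scaled key/query softmaxes, as
        # proposed in the question above (not part of this repository).
        temperature = head_key_channels ** 0.25  # per-head d_k ** 0.25
        key = f.softmax(
            keys[:, i * head_key_channels: (i + 1) * head_key_channels, :]
            / temperature,
            dim=2,
        )
        query = f.softmax(
            queries[:, i * head_key_channels: (i + 1) * head_key_channels, :]
            / temperature,
            dim=1,
        )
        return key, query
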
Owner
Shen Zhuoran
Software Engineer @ponyai (Pony.ai). Alumnus of HKU and @google AI Residency, Diamond player of StarCraft II.