Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute


Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The method uses the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
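As a rough illustration of that idea, the sketch below builds a single content lambda from a context and applies it to every query. It is not the library's implementation: the function name `content_lambda_sketch`, the single-head layout, and the tensor shapes are assumptions chosen for illustration.

```python
import torch

def content_lambda_sketch(q, k, v):
    """Hypothetical helper. q: (b, n, dk) queries, k: (b, m, dk) keys, v: (b, m, dv) values."""
    # Normalize the keys over the m context positions
    k = k.softmax(dim=1)
    # Summarize the whole context into one linear map (the "lambda") of shape (b, dk, dv)
    lam = torch.einsum('bmk,bmv->bkv', k, v)
    # Apply that same linear map to each query independently
    return torch.einsum('bnk,bkv->bnv', q, lam)

q = torch.randn(2, 10, 16)
k = torch.randn(2, 10, 16)
v = torch.randn(2, 10, 8)
y = content_lambda_sketch(q, k, v)
print(y.shape)  # torch.Size([2, 10, 8])
```

Because the context is collapsed into a (dk, dv) map before it meets the queries, no n x m attention map is ever materialized in this sketch.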

Yannic Kilcher's paper review


$ pip install lambda-networks


Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/ or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)


LambdaNetworks: Modeling long-range Interactions without Attention. Submitted to International Conference on Learning Representations (under review).
  • Contiguity problem:

    Contiguity problem: "RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input."

    It seems that LambdaLayer breaks contiguity when I try it.

    >> False

    I have to use .contiguous() when I train with it. Is this normal?

    opened by Whiax 5
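For readers unfamiliar with the term, the minimal sketch below shows what contiguity means in PyTorch and what `.contiguous()` does. It does not reproduce the report above; it uses a plain `permute` as a stand-in for any operation that returns a strided view.

```python
import torch

x = torch.randn(1, 32, 64, 64)

# permute returns a view with rearranged strides, so the result is not densely packed
y = x.permute(0, 2, 3, 1)
print(y.is_contiguous())  # False

# .contiguous() copies the data into a densely packed tensor with the same values
z = y.contiguous()
print(z.is_contiguous())  # True
```

The copy costs extra memory and time, but the values are unchanged, so calling `.contiguous()` before an operator that requires dense input is a common workaround.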
  • Warning: Mixed memory format inputs detected while calling the operator.

    I have added lambda layers to every block of ResNet, and the following warning appears. Will it affect the result?

    Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

    opened by Pluto1314 4
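One way to check whether memory format alone changes numerical results is to run the same operator on a default-layout input and on a channels_last copy of it. The sketch below does this with an ordinary `nn.Conv2d` as a stand-in; it is an assumption that the same comparison carries over to a network containing lambda layers.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(32, 32, 3, padding=1)
x = torch.randn(1, 32, 64, 64)

# Same values, different physical layout in memory
x_cl = x.contiguous(memory_format=torch.channels_last)

with torch.no_grad():
    out_default = conv(x)
    out_mixed = conv(x_cl)  # input is channels_last, weights are not

# The two results are expected to agree up to floating-point rounding
print(torch.allclose(out_default, out_mixed, atol=1e-4))
```

If the warning itself is a nuisance, converting the module as well with `conv.to(memory_format=torch.channels_last)` keeps inputs and weights in one format; whether that silences the warning in the setup described above is untested here.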
  • Implementation of Lambda convolution

    Thanks for the great work in the implementations!

    I would like to ask whether there is a difference in using 'Conv2d' as suggested in Eq. 3 in the paper and your implementation of 'Conv3d'. These two convs treat the (h x w)-dimension as a 1-d sequence and a 2-d image, respectively. I believe they are quite different in concept.

    Please point out if I have misunderstood.

    Thx a lot.

    opened by romulus0914 3
  • How lambda layer handle the downsample in LambdaResNet?

    Hi, thanks for your clear code. I am trying to implement LambdaResNet. Does the lambda layer replace all conv2d layers? If so, how does the lambda layer handle the downsampling done by conv2d, such as stride=2? Or should I keep conv2d where stride=2 and replace only the conv2d layers with stride=1?

    opened by qiaoran-dawnlight 2
  • question about hybrid lambdaResnet


    In the paper, there is this paragraph:

    When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50, 3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and 8 lambda layers for LambdaResNet-420.

    I have several questions about constructing the hybrid lambdaResnet:

    1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in C4 and C5 (as in the ablation study)?
    2. When there is more than one lambda layer, as in LambdaResNet101, are we replacing the 3x3 conv with 3 lambda layers? And in the ResNet50 case, do we replace the 3x3 conv with 1 lambda layer?
    opened by CoinCheung 2
  • Question: Is there an easy way to visualise lambdas?

    I want to train a classifier and tell which regions it pays the most... well... attention to :) I would also like to do this during inference, without using Grad-CAM etc. Can I do this?

    opened by lebionick 1
  • Fix relative positional attention (position lambda)

    Hi lucidrains, thanks for your nice work.

    I found an error in the position lambda (relative positional attention) implementation. Relative positional attention, λp, should be translation equivariant, as written in Sec. 3.2 of the paper. This means that the positional embedding has a constraint, E[n, m] = E[t(n), t(m)], but it is missing in the current implementation. This PR fixes it by adding the translation equivariance constraint. I checked that this PR improves the result in my experiment.

    NOTE that this PR modifies the function parameter n, from total area (n=w*h) to the length of each side (n=w=h).

    opened by khanrc 1
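The constraint E[n, m] = E[t(n), t(m)] says that the embedding may depend only on the offset between the query position n and the context position m. The sketch below illustrates one way to index embeddings so that this holds; it is not the code from the PR, and the helper name `relative_position_index` and the (2n-1) x (2n-1) table size are assumptions.

```python
from itertools import product

def relative_position_index(n):
    """Hypothetical helper: map each (query, context) pair on an n x n grid
    to an index into a (2n-1) x (2n-1) table of relative-offset embeddings."""
    positions = list(product(range(n), range(n)))
    index = {}
    for (qi, qj), (ci, cj) in product(positions, positions):
        # Offsets lie in [-(n-1), n-1]; shift by n-1 to get non-negative indices
        index[(qi, qj), (ci, cj)] = (ci - qi + n - 1, cj - qj + n - 1)
    return index

idx = relative_position_index(4)

# Translating both positions by the same amount leaves the index unchanged
print(idx[(0, 0), (1, 2)] == idx[(2, 1), (3, 3)])  # True
```

Looking embeddings up through such an index ties together all position pairs that share an offset, which is what enforces the constraint.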
  • Use of Keras Lambda

    Hey! Thanks for the awesome implementations :D

    I was wondering why tf.keras.layers.Lambda is used? It seems unnecessary: regular calls to TF operations work and are more readable.

    You can also call the functional version of the softmax instead.

    opened by cgarciae 1
  • Lambda for a sequence of images

    Thanks for the quick implementation!

    I have a problem where I have a sequence of images (a video) rather than one. So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that gives the sequence length.

    Given a known max_length (for the positional embedding), in forward, should conv4d be used instead of 3d to allow interaction between frames?

    In the paper they do mention this could serve as a general framework for sequences of images, so I wonder if you explored that in implementation (where obviously a single image is just a case where length=1)

    opened by AmitMY 1
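A simple baseline for video input, which does not answer the conv4d question, is to fold the length dimension into the batch and apply a 2D layer frame by frame. The sketch below assumes a (batch, length, channels, height, width) layout; `apply_per_frame` is a hypothetical helper, and an `nn.Conv2d` stands in for a 2D layer such as LambdaLayer.

```python
import torch
from torch import nn

def apply_per_frame(layer, video):
    """Hypothetical helper: apply a 2D layer to every frame of a
    (batch, length, channels, height, width) tensor."""
    b, t, c, h, w = video.shape
    frames = video.reshape(b * t, c, h, w)   # fold length into batch
    out = layer(frames)
    return out.reshape(b, t, *out.shape[1:])  # unfold back to (b, t, ...)

layer = nn.Conv2d(32, 32, 1)  # stand-in for a 2D layer such as LambdaLayer
video = torch.randn(2, 5, 32, 16, 16)
out = apply_per_frame(layer, video)
print(out.shape)  # torch.Size([2, 5, 32, 16, 16])
```

Each frame is processed independently here, so there is no interaction between frames; anything temporal would have to be added on top.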
  • why flops so high

    I used ResNet-50 and changed the C4 layer into LambdaBottleNeck, but why are the FLOPs so high, about 20G with an input size of 224*244? Is that right, or is something wrong with my implementation?

    opened by zisuina 0
  • How to load_model correctly

    Hi everyone, I am struggling to load this model when it is saved in an .h5 file.

    What is the correct way to load this network? If I use custom_objects I get __init__() got an unexpected keyword argument 'name'

    opened by JNaranjo-Alcazar 0
  • Please add clarity to code

    Phil, I love your work, but I wish you could go an extra few steps to help out users. I found this class by François-Guillaume (@frgfm), which adds clear math comments. I want to merge it, but there is a bit of code drift and I don't want to introduce any bugs. I beseech you to go the extra step to help users bridge from papers to code.

    E.g. please annotate return types: def forward(self, x: torch.Tensor) -> torch.Tensor:

    Please add clarifying comments on the arguments and steps, e.g. # Project input and context to get queries, keys & values

    Throw in some maths as a comment. This is great, as it bridges the paper to the code.

    B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)

    import torch
    from torch import nn, einsum
    import torch.nn.functional as F
    from typing import Optional

    __all__ = ['LambdaLayer']


    class LambdaLayer(nn.Module):
        """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
        <>`_. The implementation was adapted from `lucidrains'

        Args:
            in_channels (int): input channels
            out_channels (int, optional): output channels
            dim_k (int): key dimension
            n (int, optional): number of input pixels
            r (int, optional): receptive field for relative positional encoding
            num_heads (int, optional): number of attention heads
            dim_u (int, optional): intra-depth dimension
        """

        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dim_k: int,
            n: Optional[int] = None,
            r: Optional[int] = None,
            num_heads: int = 4,
            dim_u: int = 1
        ) -> None:
            super().__init__()
            self.u = dim_u
            self.num_heads = num_heads

            if out_channels % num_heads != 0:
                raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
            dim_v = out_channels // num_heads

            # Project input and context to get queries, keys & values
            self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
            self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
            self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)

            self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

            # Local context uses a convolutional position lambda; global context uses a full embedding
            self.local_contexts = r is not None
            if r is not None:
                if r % 2 != 1:
                    raise AssertionError('Receptive kernel size should be odd')
                self.padding = r // 2
                self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
            else:
                if n is None:
                    raise AssertionError('You must specify the total sequence length (h x w)')
                self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape

            # Project inputs & context to retrieve queries, keys and values
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)

            # Normalize queries & values
            q = self.norm_q(q)
            v = self.norm_v(v)

            # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)
            q = q.reshape(b, self.num_heads, -1, h * w)
            # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W)
            k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
            # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W)
            v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)

            # Normalized keys
            k = k.softmax(dim=-1)

            # Content function
            λc = einsum('b u k m, b u v m -> b k v', k, v)
            Yc = einsum('b h k n, b k v -> b n h v', q, λc)

            # Position function
            if self.local_contexts:
                # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
                v = v.reshape(b, self.u, v.shape[2], h, w)
                λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
                Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
            else:
                λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
                Yp = einsum('b h k n, b n k v -> b n h v', q, λp)

            Y = Yc + Yp
            # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
            out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
            return out
    opened by johndpope 1
  • Image Size

    Are non-square image blocks allowed for context? Using global context and non-square dimensions (96, 128), I get an error on this line about dimension size.

    λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)

    opened by anklebreaker 0
  • LambdaResNet Implementation?

    I have been looking around and found one implementation of LambdaResNets, although there seem to be some metric performance problems and I've found wall-clock performance problems (runs ~7x slower than normal resnets).

    Do you plan on putting out a lambdaresnet model in this repository?

    opened by nollied 4
Phil Wang
Working with Attention. It's all we need.