Global Filter Networks for Image Classification

Overview


Created by Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, Jie Zhou

This repository contains the PyTorch implementation of GFNet.

Global Filter Network (GFNet) is a transformer-style architecture that learns long-range spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform.
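
Taken together, these three operations amount to a circular convolution whose kernel is as large as the whole feature map (the convolution theorem). This is where the global receptive field comes from. The sketch below checks the equivalence numerically on a single channel. `circular_conv2d` is a hypothetical, deliberately naive helper written only for the comparison, and the 6x5 map size is arbitrary. With `norm='ortho'` on both transforms, as in GFNet, the two 1/sqrt(HW) factors combine into the 1/(HW) of an ordinary inverse DFT, so the identity holds without extra scaling.

```python
import torch

def circular_conv2d(x, k):
    """Direct 2D circular convolution of x (H, W) with a full-size kernel k (H, W)."""
    H, W = x.shape
    out = torch.zeros_like(x)
    for i in range(H):
        for j in range(W):
            for m in range(H):
                for n in range(W):
                    # Indices wrap around the borders (periodic boundary).
                    out[i, j] += k[m, n] * x[(i - m) % H, (j - n) % W]
    return out

torch.manual_seed(0)
H, W = 6, 5
x = torch.randn(H, W, dtype=torch.float64)  # one channel of a feature map
k = torch.randn(H, W, dtype=torch.float64)  # spatial kernel as large as the map

# Frequency-domain filtering as in GFNet: rfft2 -> multiply -> irfft2 (norm='ortho').
filt = torch.fft.rfft2(k)  # the "global filter", shape (H, W // 2 + 1)
y_freq = torch.fft.irfft2(torch.fft.rfft2(x, norm='ortho') * filt, s=(H, W), norm='ortho')

y_direct = circular_conv2d(x, k)
print(torch.allclose(y_freq, y_direct))  # True
```

In GFNet the filter is learned directly in the frequency domain, so no spatial kernel is ever materialized. The loop above costs O((HW)^2) per channel, while the FFT route costs O(HW log HW).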


Our code is based on pytorch-image-models and DeiT.

[Project Page] [arXiv]

Global Filter Layer

GFNet is a conceptually simple yet computationally efficient architecture that consists of several stacked Global Filter Layers and Feedforward Networks (FFNs). The Global Filter Layer mixes tokens with log-linear complexity, benefiting from the highly efficient Fast Fourier Transform (FFT) algorithm. The layer is easy to implement:

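import torch
import torch.nn as nn
import torch.fft

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable global filter stored as (real, imag) pairs. For an H x W token
        # grid, h must equal H and w must equal W // 2 + 1 (the width of the rfft2
        # output); the defaults match a 14 x 14 grid.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x):
        # x: (B, H, W, C) spatial tokens
        B, H, W, C = x.shape
        # 2D FFT over the spatial axes -> complex tensor of shape (B, H, W // 2 + 1, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        # Element-wise multiplication with the global filter (broadcast over the batch)
        x = x * weight
        # Inverse 2D FFT back to a real (B, H, W, C) tensor
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x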
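
A quick shape check follows. It assumes a 224x224 input with 16x16 patches (a 14x14 token grid) and uses an illustrative embedding dimension of 384. Because the input is real, `rfft2` keeps only the non-redundant half of the spectrum along the last transformed axis. A 14x14 grid therefore produces 14 x (14 // 2 + 1) = 14 x 8 frequencies, which is why the defaults are `h=14, w=8`. The class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.fft

class GlobalFilter(nn.Module):
    # Same layer as above, repeated so this snippet is self-contained.
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x * torch.view_as_complex(self.complex_weight)
        return torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

H = W = 14  # token grid for a 224x224 image with 16x16 patches (assumed)
dim = 384   # illustrative embedding dimension
tokens = torch.randn(2, H, W, dim)

# Half spectrum along the last spatial axis: W // 2 + 1 = 8 columns.
spec = torch.fft.rfft2(tokens, dim=(1, 2), norm='ortho')
print(spec.shape)  # torch.Size([2, 14, 8, 384])

layer = GlobalFilter(dim, h=H, w=W // 2 + 1)
out = layer(tokens)
print(out.shape)   # torch.Size([2, 14, 14, 384])
```

Because `complex_weight` has a fixed (h, w) shape, a token grid of a different size fails to broadcast against it unless the filter is resized first.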

Compared to self-attention and spatial MLPs, our Global Filter Layer is much more efficient at processing high-resolution feature maps:

[Figure: efficiency comparison]
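
To make the efficiency comparison concrete, the sketch below compares rough, order-of-magnitude operation counts for token mixing over an N = H*W token grid with D channels. The cost models are simplifying assumptions, not the exact accounting used in the paper. Self-attention is counted as about N^2 * D for QK^T plus N^2 * D for the attention-weighted sum, with the projections omitted. The global filter is counted as about N * D * log2(N) for each FFT, plus N * D for the element-wise product.

```python
import math

def attention_mixing_ops(h, w, d):
    """Rough token-mixing cost of self-attention (assumed cost model)."""
    n = h * w
    # QK^T and the attention-weighted sum of V each cost about n * n * d.
    return 2 * n * n * d

def global_filter_ops(h, w, d):
    """Rough cost of a global filter layer (assumed cost model)."""
    n = h * w
    fft = n * d * math.log2(n)   # forward 2D FFT
    ifft = n * d * math.log2(n)  # inverse 2D FFT
    mult = n * d                 # element-wise product with the learnable filter
    return fft + ifft + mult

d = 384
for side in (14, 28, 56):
    ratio = attention_mixing_ops(side, side, d) / global_filter_ops(side, side, d)
    print(f"{side}x{side} tokens: attention / global filter ~ {ratio:.0f}x")
```

The ratio grows roughly like N / log N, which is why the gap widens as the feature map gets larger.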

Model Zoo

We provide our GFNet models pretrained on ImageNet:

name        arch        Params  FLOPs  acc@1 (%)  acc@5 (%)  url
GFNet-Ti    gfnet-ti    7M      1.3G   74.6       92.2       Tsinghua Cloud / Google Drive
GFNet-XS    gfnet-xs    16M     2.8G   78.6       94.2       Tsinghua Cloud / Google Drive
GFNet-S     gfnet-s     25M     4.5G   80.0       94.9       Tsinghua Cloud / Google Drive
GFNet-B     gfnet-b     43M     7.9G   80.7       95.1       Tsinghua Cloud / Google Drive
GFNet-H-Ti  gfnet-h-ti  15M     2.0G   80.1       95.1       Tsinghua Cloud / Google Drive
GFNet-H-S   gfnet-h-s   32M     4.5G   81.5       95.6       Tsinghua Cloud / Google Drive
GFNet-H-B   gfnet-h-b   54M     8.4G   82.9       96.2       Tsinghua Cloud / Google Drive

Usage

Requirements

  • torch>=1.8.1
  • torchvision
  • timm

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be as follows:

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
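
This layout is the one torchvision's `ImageFolder` expects: one sub-directory per class under each split. A quick sanity check before a long training run can save time. The helper below is hypothetical, not part of this repository, and only counts class folders and image files per split. For the full ILSVRC2012 set, one would expect 1000 classes in both splits, with 1,281,167 training and 50,000 validation images.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def check_imagenet_layout(root):
    """Return {split: (num_classes, num_images)} for an ImageFolder-style tree.

    Hypothetical helper: checks that train/ and val/ exist and counts the
    class sub-directories and the image files inside them.
    """
    root = Path(root)
    summary = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        num_images = sum(
            1
            for class_dir in class_dirs
            for f in class_dir.iterdir()
            if f.suffix.lower() in IMAGE_SUFFIXES
        )
        summary[split] = (len(class_dirs), num_images)
    return summary
```

Usage: `print(check_imagenet_layout("/path/to/ILSVRC2012/"))`.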

Evaluation

To evaluate a pre-trained GFNet model on the ImageNet validation set with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --path /path/to/model
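
The acc@1 and acc@5 columns in the model zoo are the usual top-1 and top-5 accuracies. The sketch below shows how these metrics are typically computed from logits. It is not taken from infer.py, which may rely on timm's `accuracy` helper instead, and the logits are made-up illustrative values.

```python
import torch

def topk_accuracy(logits, targets, ks=(1, 5)):
    """Percentage of samples whose target is among the top-k predicted classes."""
    maxk = max(ks)
    # Indices of the maxk highest-scoring classes per sample: (batch, maxk).
    top = logits.topk(maxk, dim=1).indices
    hits = top == targets.unsqueeze(1)  # (batch, maxk) boolean
    return [hits[:, :k].any(dim=1).float().mean().item() * 100 for k in ks]

logits = torch.tensor([
    [0.1, 0.9, 0.0, 0.0, 0.0, 0.0],  # top-1 prediction: class 1
    [0.5, 0.1, 0.2, 0.3, 0.0, 0.4],  # top-1 prediction: class 0
])
targets = torch.tensor([1, 2])
acc1, acc5 = topk_accuracy(logits, targets)
print(acc1, acc5)  # 50.0 100.0
```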

Training

ImageNet

To train GFNet models on ImageNet from scratch, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs --arch gfnet-xs --batch-size 128 --data-path /path/to/ILSVRC2012/
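
In DeiT-derived training scripts, `--batch-size` is the per-GPU batch. With `--nproc_per_node=8`, the command above therefore trains with an effective batch of 8 x 128 = 1024 images. DeiT also scales the base learning rate linearly by (global batch / 512). Whether main_gfnet.py keeps exactly this rule is an assumption, as is the base learning rate of 5e-4 (DeiT's default). The arithmetic would then look like this:

```python
def effective_batch(per_gpu_batch, num_gpus):
    """Global batch size seen by one optimizer step."""
    return per_gpu_batch * num_gpus

def linear_scaled_lr(base_lr, per_gpu_batch, num_gpus, reference_batch=512):
    # DeiT-style linear scaling rule (assumed to carry over to main_gfnet.py).
    return base_lr * effective_batch(per_gpu_batch, num_gpus) / reference_batch

print(effective_batch(128, 8))         # 1024
print(linear_scaled_lr(5e-4, 128, 8))  # 0.001
print(effective_batch(64, 8))          # 512, as in the 384x384 finetuning command
```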

To finetune a pre-trained model at higher resolution, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs-img384 --arch gfnet-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model
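
At 384x384 input, and assuming a patch size of 16, the token grid grows from 14x14 to 24x24. Each global filter must therefore change from 14x8 to 24x13 frequencies before the pretrained weights fit. One plausible heuristic is to interpolate the real and imaginary parts of `complex_weight` as if they were a small image. The helper below is a hypothetical sketch of that idea, not necessarily what `--finetune` does in main_gfnet.py. Note that bilinear resizing of an FFT-ordered grid is only an approximation of resampling the filter's frequency response.

```python
import torch
import torch.nn.functional as F

def resize_global_filter(weight, new_h, new_w):
    """Resize a (h, w, dim, 2) global-filter weight to (new_h, new_w, dim, 2).

    Hypothetical helper: bilinear interpolation over the two frequency axes,
    applied independently to every channel and to the real/imaginary parts.
    """
    h, w, dim, two = weight.shape
    # (h, w, dim, 2) -> (1, dim * 2, h, w): channels first, as F.interpolate expects.
    x = weight.permute(2, 3, 0, 1).reshape(1, dim * two, h, w)
    x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)
    # Back to (new_h, new_w, dim, 2).
    return x.reshape(dim, two, new_h, new_w).permute(2, 3, 0, 1).contiguous()

patch = 16  # assumed patch size
old_grid, new_grid = 224 // patch, 384 // patch  # 14 and 24
old = torch.randn(old_grid, old_grid // 2 + 1, 384, 2) * 0.02
new = resize_global_filter(old, new_grid, new_grid // 2 + 1)
print(old.shape, new.shape)
# torch.Size([14, 8, 384, 2]) torch.Size([24, 13, 384, 2])
```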

Transfer Learning Datasets

To finetune a pre-trained model on a transfer learning dataset, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet_transfer.py  --output_dir logs/gfnet-xs-cars --arch gfnet-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --dist-eval --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model 
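
Stanford Cars has 196 classes, so the 1000-way ImageNet classifier in a pretrained checkpoint does not match the new head. DeiT's finetuning code handles this by dropping checkpoint entries whose shapes disagree with the new model and loading the rest non-strictly. Whether main_gfnet_transfer.py does exactly this is an assumption. The sketch shows the pattern on a hypothetical toy model; the parameter names `head.weight`/`head.bias` follow the timm convention and are assumed to match GFNet's checkpoints.

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    """Hypothetical stand-in for a GFNet: a backbone plus a linear `head`."""
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.head = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

def load_for_transfer(model, checkpoint_state):
    """Load matching weights; skip entries whose shapes differ (e.g. the head)."""
    own = model.state_dict()
    filtered = {
        k: v for k, v in checkpoint_state.items()
        if k in own and own[k].shape == v.shape
    }
    skipped = sorted(set(checkpoint_state) - set(filtered))
    result = model.load_state_dict(filtered, strict=False)
    # The skipped head stays randomly initialized and shows up as missing.
    return skipped, result.missing_keys

pretrained = ToyNet(num_classes=1000).state_dict()  # stands in for an ImageNet checkpoint
model = ToyNet(num_classes=196)                      # Stanford Cars head
skipped, missing = load_for_transfer(model, pretrained)
print(skipped)  # ['head.bias', 'head.weight']
```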

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021global,
  title={Global Filter Networks for Image Classification},
  author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
  journal={arXiv preprint arXiv:2107.00645},
  year={2021}
}
Comments
  • Flexible input size

    Hi, I came across your work and found it a very interesting concept. Currently, the network only accepts a fixed input size. Is there a way to support flexible input sizes? I realize the main constraint is the following line, where the complex weight is defined at initialization time:

        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    Is there a way to modify this line so that we can feed inputs of different sizes? Thanks

    opened by repers 8
  • Question about Complexity (FLOPs)

    Hi, interesting work! I wonder how the complexity (FLOPs) of the global filter in Table 1 is calculated. Since the Fourier transform of a real signal is conjugate symmetric, we have:

    case 1: taking conjugate symmetry into account.
    RFFT: HWD/2 * log2(HW)
    Global Filter: HWD/2
    IRFFT: HWD/2 * log2(HW)

    Thus, the total complexity (FLOPs) of the global filter is HWD * log2(HW) + HWD/2. Is that right?

    case 2: ignoring conjugate symmetry.
    RFFT: HWD * log2(HW)
    Global Filter: HWD
    IRFFT: HWD * log2(HW)

    Thus, the total complexity (FLOPs) of the global filter is 2HWD * log2(HW) + HWD. Is that right?

    Which one is correct?

    opened by techmonsterwang 5
  • The fp32fft option

    Hello, thanks for your nice work!

    I wonder what the fp32fft option does. In my experiments, the input and output of the fft function are already torch.float32, so I'm not sure why there is an option for converting to fp32. Thanks in advance

    opened by liuzhuang13 5
  • Question about adversarial robustness

    Hi Rao, thank you for your great work! When measuring GFNet's adversarial robustness with FGSM and PGD, could you share the specific settings and hyperparameters you used? It would be even better if I could get the code you used!

    opened by JHLEE17 2
  • Memory and FLOPs concern?

    Hi! Very interesting work!

    How are the parameter counts (Params) calculated? Do you use profile?

    I have noticed that you use a script to calculate memory and FLOPs. Could you share the script? Many thanks to the authors~

    opened by techmonsterwang 2
  • About size h and w

    Thank you for your excellent work.

    I'm curious why w is set to h // 2 + 1. Is there a reason for this, or did experiments simply show that w = h // 2 + 1 works better?

    opened by ksl1231 2
  • Question about block design

    Hello, thanks for your great work!

    In your figure and code, there is no skip connection after the global filter layer.

    This is different from the original Transformer implementation, which has two skip connections per block (one for the self-attention layer and one for the FFN layer).

    For example, the original Transformer uses blocks like

    x = x + SA(x)
    x = x + FFN(x)
    

    But the Global Filter Network uses the following block:

    x_ = Global_Filter(x)
    x = x + FFN(x_)
    

    Is there any reason for adopting the current block architecture?

    Thanks,

    opened by hsi1032 2
  • Question about 3D configuration

    Hi, thank you for your excellent work. I have some questions about a 3D configuration of GF-Net. I extended your model to 3D by using a 3D FFT and IFFT for global filter learning, and tested it on 3D data classification (point clouds / volumetric data). However, the model overfits (97% accuracy on the training set, 75% on the test set). Could you give some advice on how to train such a model? (I have tried dropout with different ratios.) Thank you~

    opened by jerryzhang1119 2
  • image size for ADE20K

    Hi, Yongming

    What image size did you use for training and validation on ADE20K?

    I noticed that PVT used 512x512 for training and a different scale for testing. However, as the parameters of Global Filter are related to the image size, how do you deal with the scale change?

    Thanks in advance.

    opened by ShoufaChen 2
  • FLOPs Concern

    Hello, thanks for your nice work!

    I tested the FLOPs of GFNet using the fvcore library as follows:

        import torch
        from fvcore.nn import flop_count

        # `model` is a GFNet instance (e.g. gfnet-h-b)
        model_mode = model.training
        model.eval()
        fake_input = torch.rand(1, 3, 224, 224)
        flops_dict, *_ = flop_count(model, (fake_input,))
        count = sum(flops_dict.values())  # flop_count reports GFLOPs per operator type
        model.train(model_mode)
        print("fvcore FLOPs: {:.3f} G".format(count))
    

    For the gfnet-h-b model, I got 8.547 G, which is slightly higher than the value reported in the paper, i.e., 8.4G.

    My concerns are:

    1. How do you get the FLOPs value?

    2. From the fvcore log, I noticed that:

    Unsupported operator aten::fft_irfft2 encountered 36 time(s)
    

    I.e., the FLOPs of the fft_irfft2 operator are not taken into account.
    I wonder whether you took this operator into account when calculating the FLOPs?

    If not, I think it would be better to consider it because this is the core operator that replaces self-attention.

    Please let me know if I missed something.

    Thanks.

    opened by ShoufaChen 2
  • parameters

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

        def forward(self, x):
            B, H, W, C = x.shape
            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
            return x
    

    Thank you very much for your work. I have some questions: what is the meaning of "h=14, w=8" and of "s=(H, W), dim=(1, 2)"?

    opened by 123456789-qwer 1
  • About visualization

    Thank you very much for your work. However, when I implemented the visualization of the frequency-domain filters, the images differ from the ones you provided. I would also like to ask how to visualize the filters in the spatial domain.

    opened by DingDinmao 1