Global Filter Networks for Image Classification


Created by Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, Jie Zhou

This repository contains the PyTorch implementation of GFNet.

Global Filter Networks is a transformer-style architecture that learns long-term spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform.


Our code is based on pytorch-image-models and DeiT.

[Project Page] [arXiv]

Global Filter Layer

GFNet is a conceptually simple yet computationally efficient architecture that consists of a stack of Global Filter Layers and Feedforward Networks (FFN). The Global Filter Layer mixes tokens with log-linear complexity thanks to the highly efficient Fast Fourier Transform (FFT) algorithm. The layer is easy to implement:

import torch
import torch.nn as nn
import torch.fft

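class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()  # must run before parameters are registered on an nn.Module
        # Learnable global filter stored as (real, imag) pairs.
        # For a 14x14 feature map, rfft2 yields a 14x8 half-spectrum, hence h=14, w=8.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x):
        B, H, W, C = x.shape
        # 2D FFT over the spatial dimensions (H, W)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        # Element-wise multiplication with the global filter in the frequency domain
        x = x * weight
        # Inverse 2D FFT back to the spatial domain; s restores the original (H, W)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x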

Compared to self-attention and spatial MLP, our Global Filter Layer is much more efficient at processing high-resolution feature maps.
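To get a feel for this scaling, the sketch below compares rough token-mixing costs as the feature map grows. Both cost formulas are simplifying assumptions made for illustration (a quadratic (HW)^2 * D term for attention-style mixing and an HW * D * log2(HW) term for FFT-based mixing); they ignore constants and projection layers and are not measured FLOPs. The function names are hypothetical.

```python
import math

def attention_mixing_cost(h, w, dim):
    """Assumed leading-order cost of attention-style mixing: quadratic in the token count."""
    n = h * w
    return n * n * dim

def global_filter_cost(h, w, dim):
    """Assumed leading-order cost of FFT-based mixing: log-linear in the token count."""
    n = h * w
    return n * dim * math.log2(n)

def cost_ratio(h, w, dim):
    """How many times more expensive the quadratic term is than the log-linear one."""
    return attention_mixing_cost(h, w, dim) / global_filter_cost(h, w, dim)

# The ratio grows with resolution, which is the point of the comparison.
for side in (14, 28, 56):
    print(f"{side}x{side} tokens: ratio ~ {cost_ratio(side, side, 384):.1f}")
```

Under these assumptions the ratio reduces to HW / log2(HW), so the gap widens as the feature map gets larger, regardless of the channel dimension.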


Model Zoo

We provide our GFNet models pretrained on ImageNet:

name        arch        Params  FLOPs  acc@1 (%)  acc@5 (%)  url
GFNet-Ti    gfnet-ti    7M      1.3G   74.6       92.2       Tsinghua Cloud / Google Drive
GFNet-XS    gfnet-xs    16M     2.8G   78.6       94.2       Tsinghua Cloud / Google Drive
GFNet-S     gfnet-s     25M     4.5G   80.0       94.9       Tsinghua Cloud / Google Drive
GFNet-B     gfnet-b     43M     7.9G   80.7       95.1       Tsinghua Cloud / Google Drive
GFNet-H-Ti  gfnet-h-ti  15M     2.0G   80.1       95.1       Tsinghua Cloud / Google Drive
GFNet-H-S   gfnet-h-s   32M     4.5G   81.5       95.6       Tsinghua Cloud / Google Drive
GFNet-H-B   gfnet-h-b   54M     8.4G   82.9       96.2       Tsinghua Cloud / Google Drive



  • torch>=1.8.1
  • torchvision
  • timm

Data preparation: download and extract the ImageNet images. The directory structure should be

│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......


To evaluate a pre-trained GFNet model on the ImageNet validation set with a single GPU, run:

python --data-path /path/to/ILSVRC2012/ --arch arch_name --path /path/to/model



To train GFNet models on ImageNet from scratch, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env  --output_dir logs/gfnet-xs --arch gfnet-xs --batch-size 128 --data-path /path/to/ILSVRC2012/

To finetune a pre-trained model at higher resolution, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env  --output_dir logs/gfnet-xs-img384 --arch gfnet-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model

Transfer Learning Datasets

To finetune a pre-trained model on a transfer learning dataset, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env  --output_dir logs/gfnet-xs-cars --arch gfnet-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --dist-eval --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model 


MIT License


If you find our work useful in your research, please consider citing:

  title={Global Filter Networks for Image Classification},
  author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
  journal={arXiv preprint arXiv:2107.00645},
  • Flexible input size

    Hi, I came across your work and thought it was a very interesting concept. Currently, the network takes in fixed input sizes. But is there a way for there to be flexible input sizes? I realize the main constraint here is the following line where the complex weight is defined during initialization time: self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02) Is there a way to modify this line so that we can have inputs of different sizes? Thanks

    opened by repers 8
  • Question about Complexity (FLOPs)

    Hi, interesting work! I wonder how the complexity (FLOPs) of the global filter in Table 1 is calculated. Because the spectrum of a real signal is conjugate symmetric, I see two ways to count:

    Case 1: take the conjugate symmetry into account.
    RFFT: HWD/2 * log2(HW)
    Global Filter: HWD/2
    IRFFT: HWD/2 * log2(HW)

    Thus, the total complexity (FLOPs) of the global filter is: HWD * log2(HW) + HWD/2. Is that right?

    Case 2: ignore the conjugate symmetry.
    RFFT: HWD * log2(HW)
    Global Filter: HWD
    IRFFT: HWD * log2(HW)

    Thus, the total complexity (FLOPs) of the global filter is: 2HWD * log2(HW) + HWD. Is that right?

    Which is right?

    opened by techmonsterwang 5
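The two counting conventions in this question can be compared numerically. The sketch below simply evaluates both formulas as written in the question; it does not settle which convention the paper's Table 1 uses. The sizes H = W = 14 and D = 384 are illustrative assumptions, and the function names are hypothetical.

```python
import math

def flops_with_symmetry(h, w, d):
    """Case 1: RFFT, filter, and IRFFT each counted on half the spectrum."""
    n = h * w
    rfft = n * d / 2 * math.log2(n)
    filt = n * d / 2
    irfft = n * d / 2 * math.log2(n)
    return rfft + filt + irfft

def flops_without_symmetry(h, w, d):
    """Case 2: the same three steps counted on the full spectrum."""
    n = h * w
    return 2 * n * d * math.log2(n) + n * d

case1 = flops_with_symmetry(14, 14, 384)
case2 = flops_without_symmetry(14, 14, 384)
print(f"case 1: {case1 / 1e6:.2f} MFLOPs, case 2: {case2 / 1e6:.2f} MFLOPs")
```

Algebraically, case 2 is exactly twice case 1, so the two conventions differ only by a constant factor and give the same log-linear growth.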
  • The fp32fft option

    Hello, thanks for your nice work!

    I wonder what the fp32fft option does. In my experiments the input and output of the fft function are already torch.float32, so I'm not sure why there is an option for converting to fp32. Thanks in advance.

    opened by liuzhuang13 5
  • Question about adversarial robustness

    Hi Rao, thank you for your great work! When measuring GFNet's adversarial robustness with FGSM and PGD, which specific conditions and hyperparameters did you use? It would be even better if I could get the code you used!

    opened by JHLEE17 2
  • Memory and FLOPs concern?

    Hi! very interesting work!

    How is Params calculated? Do you use profile?

    I noticed that you use a script to calculate memory and FLOPs. Can you share the script? Many thanks to the authors!

    opened by techmonsterwang 2
  • About size h and w

    Thank you for your excellent work.

    I'm curious why w is set to w = h // 2 + 1. Is there a reason for this, or did experiments simply show that this setting works better?

    opened by ksl1231 2
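One way to see where w = h // 2 + 1 comes from is the conjugate symmetry of the DFT of a real signal, which real-input FFTs such as rfft2 exploit along the last transformed dimension. The sketch below uses a naive 1D DFT in pure Python on a random real signal of length 14; it is an illustration only, not the repository's code, and the helper name dft is hypothetical.

```python
import cmath
import random

def dft(signal):
    """Naive O(n^2) discrete Fourier transform of a 1D sequence."""
    n = len(signal)
    return [
        sum(x * cmath.exp(-2j * cmath.pi * k * t / n) for t, x in enumerate(signal))
        for k in range(n)
    ]

rng = random.Random(0)
signal = [rng.uniform(-1.0, 1.0) for _ in range(14)]  # real-valued input
spectrum = dft(signal)
n = len(signal)

# For real input, X[k] equals the complex conjugate of X[n - k],
# so the second half of the spectrum carries no new information.
max_err = max(abs(spectrum[k] - spectrum[n - k].conjugate()) for k in range(1, n))

# Only the coefficients k = 0 .. n // 2 need to be stored.
unique = n // 2 + 1
print(unique, max_err < 1e-9)
```

With n = 14 this leaves 8 independent coefficients, which matches the default h=14, w=8 in GlobalFilter.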
  • Question about block design

    Hello, thanks for your great work!

    In your figure and code, there is no skip connection after the global filter layer.

    This differs from the original transformer implementation, which has two skip connections in a single block (one for the self-attention layer and one for the FFN layer).

    For example, the original transformer uses blocks like

    x = x + SA(x)
    x = x + FFN(x)

    whereas the global filter network uses the block below:

    x_ = Global_Filter(x)
    x = x + FFN(x_)

    Is there any reason for adopting the current block architecture?


    opened by hsi1032 2
  • Question about 3D configuration

    Hi, thank you for your excellent work. I have some questions concerning a 3D configuration of GFNet. I extended your model to a 3D version by introducing a 3D FFT and IFFT for global filter learning, and tested it on 3D data classification (point cloud / volumetric data). However, the model overfits (training accuracy 97%, testing accuracy 75%). Could you provide some advice on how to train such a model? (I have tried dropout with different ratios.) Thank you~

    opened by jerryzhang1119 2
  • image size for ADE20K

    Hi, Yongming

    What image size did you use for training and validation on ADE20K?

    I noticed that PVT used 512x512 for training and a different scale for testing. However, as the parameters of Global Filter are related to the image size, how do you deal with the scale change?

    Thanks in advance.

    opened by ShoufaChen 2
  • FLOPs Concern

    Hello, thanks for your nice work!

    I tested the FLOPs of GFNet using the fvcore library as follows:

        from fvcore.nn import flop_count
        model_mode =
        fake_input = torch.rand(1, 3, 224, 224)
        flops_dict, *_ = flop_count(model, fake_input)
        count = sum(flops_dict.values())
        print("fvcore FLOPs: {:.3f} G".format(count))

    For gfnet-h-b model, I got the result: 8.547 G. It is slightly higher than what you mentioned in the paper, i.e., 8.4G.

    My concerns are:

    1. How do you get the FLOPs value?

    2. From the fvcore log, I noticed that:

    Unsupported operator aten::fft_irfft2 encountered 36 time(s)

    I.e., the FLOPs of the fft_irfft2 operator are not taken into account.
    I wonder if you consider this operator when calculating the FLOPs?

    If not, I think it would be better to consider it because this is the core operator that replaces self-attention.

    Please let me know if I missed something.


    opened by ShoufaChen 2
  • parameters


    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

        def forward(self, x):
            B, H, W, C = x.shape
            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
            return x

    Thank you very much for your work. I have some questions. What is the meaning of "h=14, w=8" and "s=(H, W), dim=(1, 2)"?

    opened by 123456789-qwer 1
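A small arithmetic sketch may help with these arguments. For an input of shape (B, H, W, C), dim=(1, 2) selects the two spatial axes, and a real-input 2D FFT keeps the full length along the first of them and only W // 2 + 1 coefficients along the last. The helper names below are hypothetical, and the channel dimension of 384 is an assumed example value.

```python
def filter_shape(height, width):
    """Shape (h, w) of the half-spectrum produced for a real (height, width) map."""
    return height, width // 2 + 1

def filter_param_count(height, width, dim):
    """Number of real-valued parameters in the filter: h * w * dim * (real, imag)."""
    h, w = filter_shape(height, width)
    return h * w * dim * 2

print(filter_shape(14, 14))             # (14, 8), the default h=14, w=8
print(filter_shape(14, 15))             # (14, 8) as well
print(filter_param_count(14, 14, 384))  # 86016
```

Because widths 14 and 15 produce the same half-spectrum shape, the inverse transform cannot infer the original width on its own, which is why s=(H, W) is passed to irfft2.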
  • About visualization

    Thank you very much for your work. When I implemented the visualization of the frequency-domain filters, the resulting image was different from the one you provided. I would also like to ask how to produce the spatial-domain visualization.

    opened by DingDinmao 1
