PyTorch code for our paper "Image Super-Resolution with Non-Local Sparse Attention" (CVPR2021).

Overview

Image Super-Resolution with Non-Local Sparse Attention

This repository is for NLSN introduced in the following paper "Image Super-Resolution with Non-Local Sparse Attention", CVPR2021, [Link]

The code is built on EDSR (PyTorch) and tested in an Ubuntu 18.04 environment (Python 3.6, PyTorch >= 1.1.0) with V100 GPUs.

Contents

  1. Introduction
  2. Train
  3. Test
  4. Citation
  5. Acknowledgements

Introduction

Both Non-Local (NL) operation and sparse representation are crucial for Single Image Super-Resolution (SISR). In this paper, we investigate their combinations and propose a novel Non-Local Sparse Attention (NLSA) with dynamic sparse attention pattern. NLSA is designed to retain long-range modeling capability from NL operation while enjoying robustness and high-efficiency of sparse representation. Specifically, NLSA rectifies non-local attention with spherical locality sensitive hashing (LSH) that partitions the input space into hash buckets of related features. For every query signal, NLSA assigns a bucket to it and only computes attention within the bucket. The resulting sparse attention prevents the model from attending to locations that are noisy and less-informative, while reducing the computational cost from quadratic to asymptotic linear with respect to the spatial size. Extensive experiments validate the effectiveness and efficiency of NLSA. With a few non-local sparse attention modules, our architecture, called non-local sparse network (NLSN), reaches state-of-the-art performance for SISR quantitatively and qualitatively.
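
The bucketing step described above can be illustrated with a small sketch. It uses plain Python rather than PyTorch to stay dependency-free, and it follows the random-rotation form of angular LSH popularized by Reformer; the function names `spherical_lsh` and `group_by_bucket` are hypothetical and this is not the repository's implementation.

```python
import math
import random


def spherical_lsh(features, n_buckets, seed=0):
    """Assign each feature vector a bucket index via a random rotation.

    Assumption: bucket = argmax over [xR, -xR], where R is a random
    Gaussian matrix with n_buckets // 2 columns (Reformer-style LSH).
    """
    assert n_buckets % 2 == 0, "n_buckets must be even"
    rng = random.Random(seed)
    dim = len(features[0])
    half = n_buckets // 2
    rotation = [[rng.gauss(0.0, 1.0) for _ in range(half)] for _ in range(dim)]

    codes = []
    for x in features:
        # Project onto the unit sphere so only the direction matters.
        norm = math.sqrt(sum(v * v for v in x)) or 1.0
        unit = [v / norm for v in x]
        rotated = [sum(unit[d] * rotation[d][k] for d in range(dim))
                   for k in range(half)]
        scores = rotated + [-r for r in rotated]
        codes.append(max(range(n_buckets), key=scores.__getitem__))
    return codes


def group_by_bucket(codes):
    """Map each bucket index to the positions that fall into it."""
    buckets = {}
    for position, code in enumerate(codes):
        buckets.setdefault(code, []).append(position)
    return buckets


if __name__ == "__main__":
    feats = [[1.0, 0.0, 0.5], [2.0, 0.0, 1.0], [-1.0, 0.0, -0.5], [0.3, 0.9, -0.2]]
    codes = spherical_lsh(feats, n_buckets=4, seed=1)
    # Attention would then be computed only among positions sharing a bucket.
    print(group_by_bucket(codes))
```

Vectors pointing in the same direction land in the same bucket, while opposite vectors land in different ones, which is the property that lets attention be restricted to related features.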

Non-Local Sparse Attention

[Figure: Non-Local Sparse Attention.]

NLSN

[Figure: Non-Local Sparse Network.]

Train

Prepare training data

  1. Download DIV2K training data (800 training + 100 validation images) from DIV2K dataset or SNU_CVLab.

  2. Set '--dir_data' to the directory that contains the HR and LR images.

For more information, please refer to EDSR (PyTorch).

Begin to train

  1. (optional) Download pretrained models for our paper.

    Pre-trained models can be downloaded from Google Drive

  2. cd to 'src' and run the following script to train the models.

    Example command is in the file 'demo.sh'.

    # Example X2 SR
    python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
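
The `--decay 200-400-600-800` flag in the command above reads as a list of epoch milestones. A minimal sketch of the resulting step schedule follows, assuming the learning rate is multiplied by a factor `gamma` at each milestone; the value `gamma=0.5` and the helper names are assumptions for illustration, not taken from this repository.

```python
def parse_milestones(decay):
    """Turn a string such as '200-400-600-800' into a list of epochs."""
    return [int(m) for m in decay.split('-')]


def lr_at_epoch(epoch, base_lr, milestones, gamma=0.5):
    """Learning rate after applying one decay step per milestone passed.

    gamma=0.5 is an assumed default, not confirmed by this README.
    """
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    milestones = parse_milestones('200-400-600-800')
    for epoch in (0, 200, 400, 600, 800, 999):
        print(epoch, lr_at_epoch(epoch, 1e-4, milestones))
```

Under these assumptions the rate is halved four times over the 1000 epochs of the example command.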
    

Test

Quick start

  1. Download benchmark datasets from SNU_CVLab

  2. (optional) Download pretrained models for our paper.

    All the models can be downloaded from Google Drive

  3. cd to 'src' and run the following scripts.

    Example command is in the file 'demo.sh'.

    # No self-ensemble: NLSN
    # Example X2 SR
    python main.py --dir_data ../../ --model NLSN  --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1  --pre_train model_x2.pt --test_only
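
Before running the test command, it can help to confirm that the benchmark folders actually contain images, since an empty test set is a likely cause of `0it` progress bars and `nan` PSNR values. The sketch below assumes an EDSR-style layout of `<dir_data>/benchmark/<name>/...`; both the layout and the helper names are assumptions, so adjust them to your setup.

```python
import os

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp')


def count_images(root):
    """Count image files under root, recursively (0 if root is missing)."""
    total = 0
    for _, _, files in os.walk(root):
        total += sum(1 for f in files if f.lower().endswith(IMAGE_EXTS))
    return total


def check_benchmarks(dir_data, names=('Set5', 'Set14', 'B100', 'Urban100')):
    """Map each benchmark name to the number of images found for it.

    Assumed layout: <dir_data>/benchmark/<name>/ (not confirmed here).
    """
    return {name: count_images(os.path.join(dir_data, 'benchmark', name))
            for name in names}


if __name__ == "__main__":
    # Use the same value you pass to --dir_data.
    for name, n in check_benchmarks('../../').items():
        print(f"{name}: {n} image files")
```

A count of zero for a dataset suggests that `--dir_data` points one level too high or too low.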

Citation

If you find the code helpful in your research or work, please cite the following papers.

@InProceedings{Mei_2021_CVPR,
    author    = {Mei, Yiqun and Fan, Yuchen and Zhou, Yuqian},
    title     = {Image Super-Resolution With Non-Local Sparse Attention},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {3517-3526}
}
@InProceedings{Lim_2017_CVPR_Workshops,
  author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
  title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
  month = {July},
  year = {2017}
}

Acknowledgements

This code is built on EDSR (PyTorch) and reformer-pytorch. We thank the authors for sharing their code.

Issues
  • How to test?

    I have the pre-trained models and followed the test command in demo.sh, but I got the following output:

        D:\softwarezijianzhuangde\anaconda\envs\pytorch-1.9\python.exe "D:/data/experiments code/code/1/Non-Local-Sparse-Attention/src/main.py" --dir_data ../benchmarkdata/benchmark/benchmark --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../experiment/test/model/model_x2.pt --test_only
        Making model...
        Loading model from ../experiment/test/model/model_x2.pt
        Total params: 41.80M

        Evaluation:
        [Set5 x2] PSNR: nan (Best: nan @epoch 1)
        [Set14 x2] PSNR: nan (Best: nan @epoch 1)
        [B100 x2] PSNR: nan (Best: nan @epoch 1)
        [Urban100 x2] PSNR: nan (Best: nan @epoch 1)
        Forward: 0.00s

        Saving...
        Total: 0.00s

        0it [00:00, ?it/s]
        0it [00:00, ?it/s]
        0it [00:00, ?it/s]
        0it [00:00, ?it/s]

        Process finished with exit code 0

    Please help me.

    opened by cheun726 11
  • Unable to train

    RuntimeError: CUDA out of memory. Tried to allocate 348.00 MiB (GPU 1; 31.72 GiB total capacity; 1008.44 MiB already allocated; 349.62 MiB free; 59.56 MiB cached)

    opened by XiaoZhang-NN 7
  • Some problems with testing

    Hi, I changed all parameters to the defaults. When running main.py, I get the following output:

        Making model...
        Loading model from ../experiment/test/model/model_x2.pt
        Total params: 44.16M
        Evaluation:
        0it [00:00, ?it/s]
        [Set5 x4] PSNR: nan (Best: nan @epoch 1)
        0it [00:00, ?it/s]
        [Set14 x4] PSNR: nan (Best: nan @epoch 1)
        0it [00:00, ?it/s]
        [B100 x4] PSNR: nan (Best: nan @epoch 1)
        0it [00:00, ?it/s]
        [Urban100 x4] PSNR: nan (Best: nan @epoch 1)

    How can I solve this problem? Thank you.

    opened by chenquan-hdu 6
  • Reason behind using optimizer.get_lr() over optimizer.get_last_lr()?

    Hi there,

    I have been able to reproduce both of your experiments and have read your paper as well. I do not understand why, at the start of the train step, you do not use the last learning rate from the previous epoch, i.e. get_last_lr(), and instead opt for get_lr(), which returns a value that is only scaled by some gamma factor (the implementation of this function is here for your reference: https://github.com/pytorch/pytorch/blob/fde94e75568b527b424b108c272793e096e8e471/torch/optim/lr_scheduler.py#L344-L352).

    The associated PyTorch warning is the following:

        lib/python3.8/site-packages/torch/optim/lr_scheduler.py:416: UserWarning: To get the last learning rate computed by the scheduler, please use get_last_lr().
          warnings.warn("To get the last learning rate computed by the scheduler, "

    Looking forward to hearing back from you soon and all the best,

    Parsa Riahi

    opened by Priahi 4
  • A question about the details of the bucket_score variable in the code

    Following the paper's description and figures, I debugged your code carefully step by step. It is very well written! I have a question about the bucket_score variable (your code is pasted below). It holds the correlation weights between different buckets; after softmax normalization these are stored in score, which is then multiplied with the y_att_buckets matrix. I understand this step.

        # unnormalized attention score
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match)  # [N, n_hashes, num_chunks, chunk_size, chunk_size*3]

        # softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
        score = torch.exp(raw_score - bucket_score)  # (after softmax)
        bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])

        # attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets)  # [N, n_hashes, num_chunks, chunk_size, C]
        ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))

    What I mainly do not understand is the code that follows. The ret computed above is multi-round, and the multi-round dimension has to be merged to obtain the final output feature of size NCHW. I do not quite see why bucket_score is then softmax-normalized and used for a weighted sum. This bucket_score is "the correlation weight between different buckets", so using it again for the weighted sum over the multi-round dimension (your code is pasted below) feels odd.

        # recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction))  # [N, n_hashes*H*W,C]
        bucket_score = torch.reshape(bucket_score, (N, -1,))  # [N,n_hashes*H*W]
        ret = batched_index_select(ret, undo_sort)  # [N, n_hashes*H*W,C]
        bucket_score = bucket_score.gather(1, undo_sort)  # [N,n_hashes*H*W]

        # weighted sum multi-round attention
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction))  # [N, n_hashes*H*W,C]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
        probs = nn.functional.softmax(bucket_score, dim=1)
        ret = torch.sum(ret * probs, dim=1)

    My personal view is that multi-round here is, to some extent, similar to multi-head in the Transformer. Following that approach, would it not be more reasonable to merge the multi-round dimension and the channel dimension into multi-round*channel, and then map back to channel with a 1×1 Conv?

    opened by Deep-imagelab 3
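
The arithmetic behind the weighting questioned in the issue above can be checked in isolation. The sketch below uses plain Python scalars instead of the repository's tensors (the helper names `attend` and `fuse_rounds` are hypothetical). It shows that weighting each round's output by a softmax over the per-round logsumexp values gives the same result as a single softmax over the scores of all rounds concatenated. This is an observation about the math, not a statement of the authors' rationale.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def attend(raw_scores, values):
    """Softmax attention for one query; also return the logsumexp."""
    lse = logsumexp(raw_scores)
    weights = [math.exp(s - lse) for s in raw_scores]  # softmax via exp(x - lse)
    return sum(w * v for w, v in zip(weights, values)), lse


def fuse_rounds(outputs, bucket_scores):
    """Weighted sum over hash rounds, with softmax(bucket_scores) as weights."""
    lse = logsumexp(bucket_scores)
    probs = [math.exp(b - lse) for b in bucket_scores]
    return sum(p * o for p, o in zip(probs, outputs))


# Two hash rounds for one query: (raw scores, candidate values) per round.
rounds = [([0.2, 1.5, -0.3], [1.0, 2.0, 3.0]),
          ([0.9, -1.0], [4.0, 5.0])]

per_round = [attend(scores, vals) for scores, vals in rounds]
fused = fuse_rounds([o for o, _ in per_round], [b for _, b in per_round])

# One softmax over all candidates from all rounds.
all_scores = [s for scores, _ in rounds for s in scores]
all_values = [v for _, vals in rounds for v in vals]
joint, _ = attend(all_scores, all_values)

assert math.isclose(fused, joint)
```

Seen this way, the per-round logsumexp acts as the total unnormalized attention mass of that round, so rounds whose buckets match the query better contribute more to the fused output.
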
  • Computational complexity

    I have some questions about computational complexity.

    1. For the input feature X ∈ R^{n×c}, do n and c refer to the height and width of the input image?
    2. "The sorting operation of a sequence with length n and m distinct numbers (bucket number) adds an additional O(nm) with quick sort." What does the length n refer to in this sentence?

    Thank you @HarukiYqM

    opened by cheun726 2
  • Computational complexity

    Hello, I would like to ask a question about the computational complexity of NLSA.

    1. For the input feature X ∈ R^{n×c}, do n and c refer to the height and width of the input image?
    2. "The sorting operation of a sequence with length n and m distinct numbers (bucket number) adds an additional O(nm) with quick sort." What does the length n refer to in this sentence?

    opened by cheun726 2
  • Test results on Urban100 deviate significantly

    Due to GPU memory limits, batch_size was set to 8 during training. Below are the test results at epoch=158:

        Evaluation:
        100%|██████████| 5/5 [00:04<00:00, 1.04it/s]
        [Set5 x2] PSNR: 38.051 (Best: 38.051 @epoch 1)
        100%|██████████| 14/14 [00:18<00:00, 1.30s/it]
        [Set14 x2] PSNR: 33.845 (Best: 33.845 @epoch 1)
        100%|██████████| 100/100 [01:33<00:00, 1.07it/s]
        [B100 x2] PSNR: 32.246 (Best: 32.246 @epoch 1)
        100%|██████████| 100/100 [06:51<00:00, 4.11s/it]
        [Urban100 x2] PSNR: 32.461 (Best: 32.461 @epoch 1)

    The results on Set5, Set14, and B100 are as expected, but the result on Urban100 is nearly 1 dB lower than the one in the paper (33.42). All training parameters were left at their defaults except batch_size, which was set to 8. What could cause such a large drop on Urban100?

    opened by laoyangui 2
  • Issue about the args.test_every

    I noticed that the args.test_every in the code is used to indirectly control the number of times the dataset is reused in each epoch (the repeat in SRData class). It affects the total number of iterations. In order to reproduce the result of your paper, I want to know the value set for args.test_every in your experiments. Thank you again and look forward to your reply.

    opened by workingcoder 2
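
As a rough guide to how `args.test_every` relates to the number of iterations per epoch, the sketch below assumes a repeat rule modeled on the SRData class in EDSR (PyTorch), namely `repeat = max(batch_size * test_every // n_images, 1)`. This rule has not been checked against this repository, and the value 1000 used in the example is only illustrative, not the setting used in the paper.

```python
def iterations_per_epoch(n_images, batch_size, test_every):
    """Return (repeat, iterations) under the assumed EDSR-style rule."""
    if n_images == 0:
        return 0, 0
    n_patches = batch_size * test_every
    repeat = max(n_patches // n_images, 1)
    samples = n_images * repeat
    # Ceiling division: the last partial batch still counts as an iteration.
    iterations = -(-samples // batch_size)
    return repeat, iterations


if __name__ == "__main__":
    # 800 DIV2K training images, batch size 16, hypothetical test_every=1000.
    print(iterations_per_epoch(800, 16, 1000))
```

Under this assumption, the number of iterations per epoch stays close to `test_every` regardless of the batch size, because the dataset is repeated to compensate.
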
  • Issue about the Evaluation Metrics

    First of all, thank you for providing the code. I noticed that PSNR and SSIM were used as metrics in the experiments of the paper. However, only the PSNR calculation is provided in the code (the calc_psnr function in utility.py). Calculating SSIM involves setting some parameters, which are not explained in detail in the paper or in this git repository. To stay consistent with the calculation process of your paper, could you provide the function that calculates SSIM (the cal_ssim function)? Thanks for your kind attention; I look forward to your prompt reply.

    opened by workingcoder 2
  • The number of NLSN operations?

    Hi, I set res_block=32, so the model should contain 5 NLSA modules. When I send an image to the model, it should go through 5 NLSA runs, but when I counted them I found that NLSA runs 20 times. Why?
    Is the model run 4 times for each image to calculate the average PSNR?

    Thank you.

    opened by chenquan-hdu 2
  • Help with codes

    Hi! Thank you for sharing this interesting work. Would you share the implementation of the non-local block using the local window? (i.e. the local window strategy mentioned in the ablation study)

    opened by nounotabe 0
Owner
Mei Yiqun, Previously @ UIUC