Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline


The overview of our ZITS. First, the TSR model is used to restore structures at low resolution. Then a simple CNN-based upsampler is leveraged to upsample the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
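
In code terms, the pipeline reads roughly as follows (an editor's sketch of the description above; the names are illustrative stand-ins, not the repository's actual API):

# conceptual sketch only: tsr/ssu/ftr stand for the three ZITS stages
def zits_inpaint(image, mask, tsr, ssu, ftr, encode_sketch):
    edge_lr, line_lr = tsr(image, mask)    # 1) restore structures at low resolution
    edge, line = ssu(edge_lr, line_lr)     # 2) CNN upsampler for the edge/line maps
    sketch = encode_sketch(edge, line)     # 3) encode the upsampled sketch space
    return ftr(image, mask, sketch)        # 4) FTR restores textures; sketch added via ZeroRA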

TO DO

We have updated the weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference code.
  • Releasing pre-trained models.
  • Releasing training code.

Preparation

  1. Preparing the environment:

    As there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions on top of torch 1.9.0 for multi-GPU training:

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provides irregular and segmentation masks (download) with different masking rates, and you should define the mask file lists before training, as in MST; a minimal sketch follows.
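
    A minimal sketch of building such a list file (an editor's example, assuming one mask path per line, matching the 'irregular_mask_list.txt' / 'coco_mask_list.txt' names used in the training commands below):

    import glob
    import os

    mask_dir = '/path/to/irregular_masks'  # hypothetical mask folder
    paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
    with open('irregular_mask_list.txt', 'w') as f:
        f.write('\n'.join(paths))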

  3. Download the pretrained masked wireframe detection model to the './ckpt' folder: LSM-HAWP (MST ICCV2021, retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    As MST trained LSM-HAWP with PyTorch 1.3.1, and this causes problems (link) when tested in PyTorch 1.9, we recommend inferring the lines (wireframes) with torch==1.3.1. If the line detection is not based on torch 1.3.1, the performance may drop slightly.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    Then extract wireframes with the following command:

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
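    The wireframes are saved one file per input image. A minimal sketch for inspecting an output file (an editor's example, assuming lsm_hawp_inference.py stores pickle dicts of line endpoints and confidence scores; check the script for the exact format):

    import pickle

    # 'example.pkl' is a hypothetical output file under <output image path>
    with open('example.pkl', 'rb') as f:
        wireframes = pickle.load(f)
    print(wireframes.keys())  # expected: line endpoints and their scores
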
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    
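    A quick sanity check that the download succeeded (an editor's sketch; the file is assumed to be a plain PyTorch state dict):

    import torch

    state = torch.load(
        'ade20k/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
        map_location='cpu')
    print(f'loaded {len(state)} entries')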

Eval


Download pretrained models on Places2 here.

Link for BaiduDrive, password: qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models into the './ckpt' folder, then modify the config file according to your image, mask, and wireframe paths, for example:
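
The checkpoint entries below appear in config_ZITS_places2.yml (they also surface in the issues at the end of this page); the data-path key names are illustrative placeholders, so use whatever keys your config file actually defines:

transformer_ckpt_path: './ckpt/best_transformer_places2.pth'
structure_upsample_path: './ckpt/StructureUpsampling.pth'
# placeholders below: substitute your config's real keys for these paths
test_image_path: '/path/to/images'
test_mask_path: '/path/to/masks'
test_wireframe_path: '/path/to/wireframes'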

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image test, the 'wireframes_inference_env' environment from step 4 is recommended for better line detection. This code only supports square images (non-square inputs will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
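
If your inputs are not square, you can center-crop them yourself first. A minimal sketch with PIL, mirroring the center-crop behavior noted above:

from PIL import Image

def center_crop_square(in_path, out_path):
    img = Image.open(in_path)
    side = min(img.size)                 # img.size is (width, height)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    img.crop((left, top, left + side, top + side)).save(out_path)

center_crop_square('./image.png', './image_square.png')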

Training


⚠️ Warning: The training code has not been fully tested yet after refactoring.

Training TSR

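# first pass: train TSR for 12 epochs without masking positional encoding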
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
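# second pass: continue training to epoch 15 with --MaP (masking positional encoding) enabled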
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend using the pretrained SSU. You can also train your own SSU by referring to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results


Acknowledgments

Cite

If you found our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Issues
  • Could you pre-upload a sample .pth file for direct debugging?

    config_ZITS_places2.yml

    transformer_ckpt_path: './ckpt/best_transformer_places2.pth'
    gen_weights_path0: './ckpt/lama_places2/InpaintingModel_gen.pth' # Not required at the time of eval
    dis_weights_path0: './ckpt/lama_places2/InpaintingModel_dis.pth' # Not required at the time of eval
    structure_upsample_path: './ckpt/StructureUpsampling.pth'

    D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py

    data = torch.load(config.structure_upsample_path, map_location='cpu')
    

    Exception raised: AttributeError: 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

    During handling of the above exception, another exception occurred:

    File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 165, in init data = torch.load(config.structure_upsample_path, map_location='cpu') File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 427, in init super().init(*args, gpu=gpu, name='InpaintingModel', rank=rank, test=test, **kwargs) File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\FTR_trainer.py", line 256, in init self.inpaint_model = DefaultInpaintingTrainingModule(config, gpu=gpu, rank=rank, test=test, **kwargs).to(gpu) File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py", line 323, in model = ZITS(config, 0, 0, True)


    PS D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main> & 'D:\pm\python\python38\python.exe' 'c:\Users\Administrator.vscode\extensions\ms-python.python-2022.4.1\pythonFiles\lib\python\debugpy\launcher' '40191' '--' 'd:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py' '--path=D:\pm\python\lama\LaMa_models\lama-places\lama-fourier\models' '--config_file=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\config_list\config_ZITS_places2.yml' '--GPU_ids=-1' '--img_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\y\i1.png' '--mask_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\mask\i1.png' '--save_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs'
    Backend TkAgg is interactive backend. Turning interactive mode on.
    BaseInpaintingTrainingModule init called
    Loading InpaintingModel StructureUpsampling...

    opened by time888 4
  • ERROR: Could not find a version that satisfies the requirement torch==1.3.1

    Hi, you recommend inferring the wireframes with torch 1.3.1 in the README, but pip could not find that version.

    ERROR: Could not find a version that satisfies the requirement torch==1.3.1 (from versions: 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0)
    ERROR: No matching distribution found for torch==1.3.1
    

    How did you install the old version?

    opened by naoki7090624 4
  • Question about loss and activation function

    Hi, I have questions about activation function and loss.

    1. Why do you calculate the loss before the activation function? According to your code, the cross-entropy loss is calculated before the sigmoid function. In a typical CNN, I think the loss is calculated after the activation. Could you tell me why?

    2. Why do you use only the cross-entropy loss? According to your code, only the cross-entropy loss is calculated in the TSR. I wonder if you could use other losses (L1 loss, feature matching loss) after upsampling, since there are convolution layers after the transformer blocks.
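
    For context: computing the cross-entropy directly on logits is the numerically stable idiom in PyTorch, since the sigmoid is fused into the loss. A minimal sketch of the equivalence (an editor's example, independent of this repository's code):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 1, 8, 8)                     # raw outputs, no sigmoid
    targets = torch.randint(0, 2, (2, 1, 8, 8)).float()  # binary edge/line maps

    # stable: the sigmoid is fused into the loss internally
    stable = F.binary_cross_entropy_with_logits(logits, targets)
    # equivalent, but can lose precision for large |logits|
    naive = F.binary_cross_entropy(torch.sigmoid(logits), targets)
    assert torch.allclose(stable, naive, atol=1e-6)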

    opened by naoki7090624 4
  • Access to the pre-trained model

    Loved the paper! The results compared to LaMa are amazing. Can I have access to the lightest pre-trained model? (Benchmarking on mobile devices.)

    Best regards, Roi

    opened by roimulia2 4
  • How to make the Upsampling model?

    Thank you for sharing your great work!!

    I tried to train the network from scratch. TSR_train.py created best.pth, latest.pth and log.txt in the ckpt directory. After that, I tried to run FTR_train.py, but there was an error: FileNotFoundError: [Errno 2] No such file or directory: './ckpt/StructureUpsampling.pth'

    How to make the Upsampling model?

    opened by naoki7090624 3
  • Question about image size

    Thank you for sharing your great work.

    I am planning to integrate your pre-trained model into lama-cleaner. However, I noticed that the single_image_test.py script only supports square images. Is this a limitation of the network structure, or just a limitation of this test script?

    opened by Sanster 3
  • FileNotFoundError

    Hello,

    thanks a lot for your great work!

    When I run the single image test:

    conda activate wireframes_inference_env
    python single_image_test.py --path <ckpt_path> --config_file <config_path> \
     --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
    

    I get the error FileNotFoundError: [Errno 2] No such file or directory: '/home/wmlce/places365_standard/places2_all/train_list.txt'

    It is triggered when single_image_test.py is executing model = ZITS(config, 0, 0, True). This path is defined in the config files. Can the script be run without this file?

    Thank you!

    opened by yatsurama 2
  • About Transformer Block

    Thank you for sharing your great works!!

    I have two questions.

    1. Could you explain your transformer block in Fig. 2 of your paper? According to your code, the transformer block consists of the following layers: AxialAttention -> my_Block_2 (CausalSelfAttention + MLP). I think the first feed-forward refers to CausalSelfAttention and Vanilla Attention refers to the MLP, but what is the last feed-forward?

    2. What are the differences between your transformer block and the ICT transformer block?

    Thank you in advance.

    opened by naoki7090624 2
  • Question about Wireframe extraction difference Single vs Batch mode

    It seems that for Single Image Test, the wireframe extraction is done on masked images:
    https://github.com/DQiaole/ZITS_inpainting/blob/main/single_image_test.py#L173 (masking)
    https://github.com/DQiaole/ZITS_inpainting/blob/main/single_image_test.py#L194 (wireframe inference)
    https://github.com/DQiaole/ZITS_inpainting/blob/main/single_image_test.py#L219 (obj_remove is False, so lines_masked and scores_masked are used)

    Side question: when should obj_remove be used? (The code also calculates wireframes for the original image, but they are not used if obj_remove is False.)

    But for Batch Test, the wireframe extraction is done on the original images (it is recommended to precompute the wireframes): https://github.com/DQiaole/ZITS_inpainting#batch-test. Then the image, edge, and line maps are masked before being passed through the transformer: https://github.com/DQiaole/ZITS_inpainting/blob/main/src/utils.py#L273

    Q. I am not an expert in wireframe extraction, but wouldn't passing a masked image for wireframe extraction give different results from passing the full image and then masking? Or is it the same, so it doesn't matter? Ideally, for inpainting, we wouldn't have access to the original unmasked images and could not extract wireframes from them. Why this difference in implementation?

    opened by ani0075 1
  • Questions/Issues about batch test evaluation.

    Mask-image correspondence: if I am batch evaluating on a custom dataset, do I need the same number of masks as evaluation images? How does the code decide the correspondence between an image and a mask? Is there any naming convention? I have generated wireframes for all 1200 evaluation images, but the provided test masks number only 1000, so I am wondering whether the batch evaluation will work.

    Possible bug: at https://github.com/DQiaole/ZITS_inpainting/blob/main/src/FTR_trainer.py#L271, if test is True (batch eval), self.val_dataset is never created.

    Wouldn't https://github.com/DQiaole/ZITS_inpainting/blob/main/src/FTR_trainer.py#L443 then throw an error?

    opened by ani0075 1
  • single_image_test

    I am debugging single_image_test.py with Visual Studio Code. At config_path = os.path.join(args.path, 'config.yml'), the following exception is raised:

    TypeError: expected str, bytes or os.PathLike object, not NoneType
    File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py", line 289, in <module>
        config_path = os.path.join(args.path, 'config.yml')

    opened by time888 1
  • Provide discriminator weights

    The config file contains this line: dis_weights_path0: './ckpt/lama_places2/InpaintingModel_dis.pth'

    Can you please provide these weights? Or point me to where I can get them?

    opened by ani0075 1
  • Finetuning on custom dataset

    Hi,

    I used your pre-trained models on my custom dataset and the inpainting results were not great. I was thinking of finetuning the pretrained weights on my custom dataset to improve inpainting quality.

    What should be the steps/commands? Which models do I need to finetune? TSR and FTR?

    It would be great if you could provide some suggestions.

    opened by ani0075 1