Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Overview of ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
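
The role of ZeroRA in the last step can be illustrated with a toy example. This is only a sketch, assuming ZeroRA denotes a zero-initialized residual addition, i.e. the encoded structure features are scaled by a learnable scalar that starts at 0 before being added to the FTR features. The function name `zero_ra`, the plain-list features and the value 0.5 are hypothetical and are not taken from the repository.

```python
from typing import List


def zero_ra(ftr_feat: List[float], structure_feat: List[float], alpha: float = 0.0) -> List[float]:
    """Add structure features to FTR features, scaled by alpha.

    alpha stands in for a learnable scalar initialized to 0 (an assumption),
    so at the start of training the FTR features pass through unchanged.
    """
    if len(ftr_feat) != len(structure_feat):
        raise ValueError("feature vectors must have the same length")
    return [f + alpha * s for f, s in zip(ftr_feat, structure_feat)]


ftr = [0.5, -1.0, 2.0]        # hypothetical FTR features
structure = [1.0, 1.0, -1.0]  # hypothetical encoded sketch features

at_init = zero_ra(ftr, structure)              # alpha = 0: identical to ftr
after_training = zero_ra(ftr, structure, 0.5)  # nonzero alpha mixes in structure
```

Starting from zero would let a pretrained texture network be extended with structure information incrementally, without disturbing its initial behaviour.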

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained model.
  • Releasing training codes.

Preparation

  1. Preparing the environment:

    As there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions, under torch 1.9.0, for multi-GPU training.

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provides irregular and segmentation masks (download) with different masking rates. Define the mask file lists before training, as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' folder: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    As MST trained LSM-HAWP in PyTorch 1.3.1, which causes problems (link) when tested in PyTorch 1.9, we recommend running line (wireframe) inference with torch==1.3.1. If line detection is not based on torch 1.3.1, performance may drop slightly.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    Then extract the wireframes with the following command:

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
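
Step 2 above asks for mask file lists to be defined before training. A minimal sketch of how such a list could be generated is given here, assuming the list is a plain text file with one mask image path per line. That format, the helper name `write_mask_list` and the example paths are assumptions; check the MST repository for the exact convention.

```python
import os
from typing import List

IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def write_mask_list(mask_dir: str, list_path: str) -> List[str]:
    """Collect mask images under mask_dir and write one absolute path per line."""
    paths = []
    for root, _dirs, files in os.walk(mask_dir):
        for name in files:
            if name.lower().endswith(IMAGE_EXTS):
                paths.append(os.path.abspath(os.path.join(root, name)))
    paths.sort()  # deterministic order across runs
    with open(list_path, "w") as f:
        for p in paths:
            f.write(p + "\n")
    return paths


# Hypothetical usage:
# write_mask_list("./masks/irregular", "irregular_mask_list.txt")
```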
    

Eval

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models in the './ckpt' folder. Then modify the config file according to your image, mask and wireframe paths.

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'
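
Before launching a batch test, it can help to confirm that every path named in the config file actually exists, since a missing checkpoint tends to surface only as a late loading error. The following is a rough pre-flight sketch, not part of the repository: it assumes the config uses flat `key: 'value'` lines and treats any key containing `path` as a file path. The helper names are hypothetical, and a `#` inside a path would be misread as a comment.

```python
import os
from typing import Dict, List


def find_path_entries(config_text: str) -> Dict[str, str]:
    """Return {key: value} for flat 'key: value' lines whose key contains 'path'."""
    entries = {}
    for line in config_text.splitlines():
        line = line.split("#", 1)[0]  # drop trailing comments
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        key, value = key.strip(), value.strip().strip("'\"")
        if "path" in key and value:
            entries[key] = value
    return entries


def missing_paths(config_text: str, base_dir: str = ".") -> List[str]:
    """List 'key: value' entries whose file or folder does not exist under base_dir."""
    missing = []
    for key, value in find_path_entries(config_text).items():
        if not os.path.exists(os.path.join(base_dir, value)):
            missing.append(f"{key}: {value}")
    return missing


# Hypothetical usage:
# with open("./config_list/config_ZITS_places2.yml") as f:
#     print(missing_paths(f.read()))
```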

Single Image Test

Note: For the single image test, the 'wireframes_inference_env' environment from step 4 is recommended for better line detection. This code only supports square images (other images will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
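
Because non-square inputs are center cropped, it may be convenient to crop the image and its mask yourself beforehand so that both stay aligned. The sketch below computes a centered square crop box; the function name `center_crop_box` is hypothetical, and the script's own cropping may differ in rounding details.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


box_landscape = center_crop_box(640, 480)
box_portrait = center_crop_box(300, 500)
```

With Pillow, the resulting tuple can be passed directly to `Image.crop`. Apply the same box to the image and to the mask.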

Training

⚠️ Warning: The training code has not been fully tested since refactoring.

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Training SSU

We recommend using the pretrained SSU. You can also train your own SSU by referring to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Acknowledgments

Cite

If you find our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Comments
  • lsm_hawp_inference.py_result_bad

    I tried to use lsm_hawp_inference.py to generate the .pkl files for my dataset (Places365), with the best_lsm_hawp.pth you provided, but the results are really bad. I tried reducing the threshold from 0.8 to 0.5, but the results are still bad.

    Do you have a best_place365_lsm_hawp.pth? Or how can we train our own HAWP?

    The image is a sample from the training set (14001.jpg).

    opened by bobo0303 13
  • Is there a demo code to make inference on custom image and mask

    Hi, I tried single_image_test.py, but it is hard coded for Places 365 Standard. Is there any simpler demo code to show the results based on a pair of inputs such as image and corresponding mask?

    opened by yijingru 6
  • ValueError: axes don't match array in inpainting_metrics.py

    Hello, the error above occurs at one line of this file. If the input is an original image from the indoor dataset that is not 256*256 in size, processing the image with the method from https://stackoverflow.com/questions/37747021/create-numpy-array-of-images makes the code run normally, but I do not know whether this affects the results. Have you encountered this problem? Could you give me some advice? Looking forward to your reply, thanks.

    opened by Ellohiye 4
  • Single image test

    Hello, your work is excellent, but I have some questions about testing the source code. In the following command, which weights path and which config file should be used?

      python single_image_test.py --path <ckpt_path> --config_file <config_path> --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./

    My own settings are:

      python single_image_test.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0' --img_path ./test_i/img1.png --mask_path ./test_i/mask1.png --save_path ./test_i/

    But the following error occurred:

      Traceback (most recent call last):
        File "single_image_test.py", line 322, in <module>
          model = ZITS(config, 0, 0, True)
        File "D:\pythonProject\7_4\inpaint\ZITS_inpainting-main\src\FTR_trainer.py", line 296, in __init__
          min_sigma=min_sigma, max_sigma=max_sigma)
        File "D:\pythonProject\7_4\inpaint\ZITS_inpainting-main\datasets\dataset_FTR.py", line 178, in __init__
          f = open(flist, 'r')
      FileNotFoundError: [Errno 2] No such file or directory: '/home/wmlce/places365_standard/places2_all/test_sub_list.txt'

    Does the single image test also need the same setup as the dataset? I hope the test steps can be described in more detail. Hoping for a reply, many thanks.

    opened by CodeMadUser 4
  • Bad results

    I am getting some very poor results. I am using the single_image script and resizing images to 512x512.

    Could some of the images and masks from the shown results be shared? That way I could verify whether I did something wrong.

    opened by mhashas 4
  • Could you upload a sample .pth file in advance for direct debugging?

    config_ZITS_places2.yml

      transformer_ckpt_path: './ckpt/best_transformer_places2.pth'
      gen_weights_path0: './ckpt/lama_places2/InpaintingModel_gen.pth' # Not required at the time of eval
      dis_weights_path0: './ckpt/lama_places2/InpaintingModel_dis.pth' # Not required at the time of eval
      structure_upsample_path: './ckpt/StructureUpsampling.pth'

    D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py

      data = torch.load(config.structure_upsample_path, map_location='cpu')

    Exception occurred: AttributeError 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

    During handling of the above exception, another exception occurred:

      File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 165, in __init__
        data = torch.load(config.structure_upsample_path, map_location='cpu')
      File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\models\FTR_model.py", line 427, in __init__
        super().__init__(*args, gpu=gpu, name='InpaintingModel', rank=rank, test=test, **kwargs)
      File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\src\FTR_trainer.py", line 256, in __init__
        self.inpaint_model = DefaultInpaintingTrainingModule(config, gpu=gpu, rank=rank, test=test, **kwargs).to(gpu)
      File "D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py", line 323, in <module>
        model = ZITS(config, 0, 0, True)

    Launch command and console output:

      PS D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main> & 'D:\pm\python\python38\python.exe' 'c:\Users\Administrator.vscode\extensions\ms-python.python-2022.4.1\pythonFiles\lib\python\debugpy\launcher' '40191' '--' 'd:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\single_image_test.py' '--path=D:\pm\python\lama\LaMa_models\lama-places\lama-fourier\models' '--config_file=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\config_list\config_ZITS_places2.yml' '--GPU_ids=-1' '--img_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\y\i1.png' '--mask_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs\mask\i1.png' '--save_path=D:\pm\python\inpaint\ZITS_inpainting-main\ZITS_inpainting-main\imgs'
      Backend TkAgg is interactive backend. Turning interactive mode on.
      BaseInpaintingTrainingModule init called
      Loading InpaintingModel StructureUpsampling...

    opened by time888 4
  • ERROR: Could not find a version that satisfies the requirement torch==1.3.1

    Hi, the README recommends running wireframe inference with torch 1.3.1, but pip could not find that version.

    ERROR: Could not find a version that satisfies the requirement torch==1.3.1 (from versions: 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0)
    ERROR: No matching distribution found for torch==1.3.1
    

    How did you install the old version?

    opened by naoki7090624 4
  • Question about loss and activation function

    Hi, I have questions about the activation function and the loss.

    1. Why do you calculate the loss before the activation function? According to your code, the cross entropy loss is calculated before the sigmoid function. In a typical CNN, I think the loss is calculated after the activation. Could you tell me why?

    2. Why do you use only cross entropy loss? According to your code, only the cross entropy loss is calculated in the TSR. I wonder whether you could use other losses (L1 loss, feature matching loss) after upsampling, because there are convolution layers after the transformer blocks.

    opened by naoki7090624 4
  • Access to the pre-trained model

    Loved the paper! The results compared to LaMa are amazing. Can I have access to the lightest pre-trained model? (Benchmarking on mobile devices)

    Best regards, Roi

    opened by roimulia2 4
  • Pretrained Indoor Model

    Hi, Can you upload the pretrained Indoor data model - the results of which you share in your paper? Also, can you share the trained models of the comparative methods you show results for in your paper?

    Thank you.

    opened by toshi2k2 3
  • wireframe model is irrelevant

    Hi,

    I've been playing quite a bit with your model because of its amazing results. One thing I've noticed is that the wireframe model appears to be irrelevant. If I return an all-zeros tensor with the same shape as the actual lines_tensor output in wf_inference_test, I get the same final outputs. Is there a bug somewhere?

    To replicate:

    return torch.zeros_like(lines_tensor.detach()) in wf_inference_test

    Update:

    The edges also seem to be useless.

      batch["line_256"] = torch.zeros_like(batch["mask_256"])
      batch["line"] = torch.zeros_like(batch["mask_512"]) 
      batch["edge"] = torch.zeros_like(batch["mask_512"])
    

    This change gives me the same results.

    Let me know if I'm doing something wrong.

    opened by mhashas 3
  • AttributeError: module 'torch.distributed' has no attribute '_reduce_scatter_base'

    The following error occurs when running TSR_train.py:

      File "TSR_train.py", line 7, in <module>
        from src.TSR_trainer import TrainerConfig, TrainerForContinuousEdgeLine, TrainerForEdgeLineFinetune
      File "D:\AIworkspace\ZITS_inpainting-main\src\TSR_trainer.py", line 14, in <module>
        from apex import amp
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\__init__.py", line 27, in <module>
        from . import transformer
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\__init__.py", line 4, in <module>
        from apex.transformer import pipeline_parallel
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\pipeline_parallel\__init__.py", line 1, in <module>
        from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\pipeline_parallel\schedules\__init__.py", line 3, in <module>
        from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\pipeline_parallel\schedules\fwd_bwd_no_pipelining.py", line 10, in <module>
        from apex.transformer.pipeline_parallel.schedules.common import Batch
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\pipeline_parallel\schedules\common.py", line 14, in <module>
        from apex.transformer.tensor_parallel.layers import (
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\tensor_parallel\__init__.py", line 21, in <module>
        from apex.transformer.tensor_parallel.layers import (
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\tensor_parallel\layers.py", line 32, in <module>
        from apex.transformer.tensor_parallel.mappings import (
      File "D:\Users\lcx\anaconda3\envs\train_env\lib\site-packages\apex\transformer\tensor_parallel\mappings.py", line 29, in <module>
        torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
      AttributeError: module 'torch.distributed' has no attribute '_reduce_scatter_base'

    My environment is torch=1.9.0+cu111, cuda=11.1. How can I solve this? Thanks.

    opened by pipiwawa 1
Owner
Qiaole Dong