Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding
by Qiaole Dong*, Chenjie Cao*, Yanwei Fu
Paper and Supplemental Material (arXiv)
Pipeline
Overview of ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler (SSU) upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
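The last fusion step can be illustrated with a minimal sketch of a zero-initialized residual addition. This assumes ZeroRA works in the ReZero style: the encoded sketch features are scaled by a learnable scalar that starts at zero before being added to the FTR features, so at initialization the texture branch behaves exactly as it did before the structure features were attached. The class name is hypothetical, and plain Python lists stand in for feature tensors.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ZeroInitResidualAdd:
    """Hypothetical ZeroRA-style fusion (assumed ReZero-like): out = x + alpha * s."""

    alpha: float = 0.0  # learnable gate; zero at init, so the texture features pass through unchanged

    def __call__(self, texture_feat: List[float], sketch_feat: List[float]) -> List[float]:
        if len(texture_feat) != len(sketch_feat):
            raise ValueError("feature shapes must match for residual addition")
        # Element-wise residual addition of the gated sketch (structure) features.
        return [x + self.alpha * s for x, s in zip(texture_feat, sketch_feat)]


fuse = ZeroInitResidualAdd()
texture = [0.5, -1.0, 2.0]
sketch = [3.0, 3.0, 3.0]
print(fuse(texture, sketch))  # alpha == 0, so this equals the texture features
fuse.alpha = 0.1  # as training updates the gate, structure information starts to flow in
print(fuse(texture, sketch))
```

Starting the gate at zero means the structure branch cannot disturb the texture network at the beginning of training; how far it opens is left to optimization.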
We have updated the weights of TSR!
Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.
TO DO
- Releasing inference code.
- Releasing pre-trained models.
- Releasing training code.
Preparation
1. Preparing the environment:
As there are some bugs when using the GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions, together with torch 1.9.0, for multi-GPU training:
conda create -n train_env python=3.6
conda activate train_env
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirement.txt
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
2. For training, MST provides irregular and segmentation masks (download) with different masking rates. You should define the mask file list before training, as in MST.
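As a minimal sketch of preparing such a list, the helper below writes one absolute mask path per line to a text file (for example 'irregular_mask_list.txt', the name used in the TSR training command). The one-path-per-line format, the '.png' extension, and the function name are assumptions; check MST's data loader for the exact format it expects.

```python
from pathlib import Path
from typing import Iterable


def write_mask_list(mask_dir: str, list_file: str, exts: Iterable[str] = (".png",)) -> int:
    """Hypothetical helper: write every mask image under mask_dir to list_file, one path per line.

    Returns the number of paths written.
    """
    suffixes = {e.lower() for e in exts}
    # Recursively collect mask images (case-insensitive extension match) and sort
    # them so the generated list is reproducible across runs.
    paths = sorted(
        p.resolve()
        for p in Path(mask_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in suffixes
    )
    Path(list_file).write_text("".join(f"{p}\n" for p in paths))
    return len(paths)
```

For example, write_mask_list('./masks/irregular', 'irregular_mask_list.txt') would produce a list for the irregular masks, and the same call on the segmentation mask folder would produce the second list.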
3. Download the pretrained masked wireframe detection model LSM-HAWP (retrained by MST, ICCV 2021, from HAWP, CVPR 2020) to the './ckpt' folder.
4. Prepare the wireframes:
As MST trained LSM-HAWP with PyTorch 1.3.1, and testing it with PyTorch 1.9 causes problems (link), we recommend inferring the lines (wireframes) with torch==1.3.1. If line detection is not run with torch 1.3.1, performance may drop slightly.
conda create -n wireframes_inference_env python=3.6
conda activate wireframes_inference_env
pip install torch==1.3.1 torchvision==0.4.2
pip install -r requirement.txt
Then extract the wireframes with the following command:
python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
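If you need to process several image folders, the command above can be scripted from Python. The sketch below only assembles the flags shown above and runs them with subprocess; the helper names and the (input_path, output_path) job layout are hypothetical. Run it from the activated 'wireframes_inference_env' so that sys.executable points at the torch 1.3.1 interpreter.

```python
import subprocess
import sys
from typing import List, Tuple


def build_wireframe_cmd(ckpt_path: str, input_path: str, output_path: str, gpu_ids: str = "0") -> List[str]:
    """Assemble the lsm_hawp_inference.py call using only the flags documented in this README."""
    return [
        sys.executable, "lsm_hawp_inference.py",
        "--ckpt_path", ckpt_path,
        "--input_path", input_path,
        "--output_path", output_path,
        "--gpu_ids", gpu_ids,  # passed as a plain string, e.g. "0"
    ]


def run_wireframe_inference(ckpt_path: str, jobs: List[Tuple[str, str]], gpu_ids: str = "0") -> None:
    """Run wireframe extraction once per (input_path, output_path) pair; stop on the first failure."""
    for input_path, output_path in jobs:
        subprocess.run(build_wireframe_cmd(ckpt_path, input_path, output_path, gpu_ids), check=True)
```

Passing the arguments as a list avoids shell quoting issues, and check=True makes a failed extraction stop the loop instead of silently producing an incomplete set of wireframes.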
5. If you need to train the model, please download the pretrained model for the perceptual loss, provided by LaMa:
mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
Eval
Download the models pretrained on Places2 here.
Link for BaiduDrive, password: qnm5
Batch Test
For batch testing, you need to complete steps 3 and 4 above.
Put the pretrained models in the './ckpt' folder. Then modify the config file according to your image, mask, and wireframe paths.
Test on images at 256 resolution:
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'
Test on images at 512 resolution:
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'
Single Image Test
Note: For single-image testing, the 'wireframes_inference_env' environment from step 4 is recommended for better line detection. This code only supports square images (non-square images will be center-cropped).
conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
--GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
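To preview what the script will see for a non-square input, the sketch below computes a centered square crop box. It assumes the crop keeps the largest centered square (side equal to the shorter image edge); the function name is hypothetical and the exact rounding in single_image_test.py may differ. The returned (left, top, right, bottom) tuple follows the box convention of Pillow's Image.crop, so the image and the mask can be cropped identically before testing.

```python
from typing import Tuple


def center_square_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest square centered in a width x height image."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    side = min(width, height)
    left = (width - side) // 2  # horizontal margin trimmed from the left
    top = (height - side) // 2  # vertical margin trimmed from the top
    return left, top, left + side, top + side


print(center_square_box(640, 480))  # (80, 0, 560, 480)
print(center_square_box(300, 301))  # (0, 0, 300, 300)
```

Cropping the image and the mask yourself with the same box keeps the hole aligned with the content you intended to inpaint.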
Training
Training TSR
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
--train_line_path [training_wireframes_path] \
--mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
--train_epoch 12 --validation_path [validation_data_path] \
--val_line_path [validation_wireframes_path] \
--valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
--train_line_path [training_wireframes_path] \
--mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
--train_epoch 15 --validation_path [validation_data_path] \
--val_line_path [validation_wireframes_path] \
--valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP
Train SSU
We recommend using the pretrained SSU. You can also train your own SSU by referring to https://github.com/ewrfcas/StructureUpsampling.
Training LaMa First
python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama
Training FTR
256:
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP
256~512:
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP
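The FTR commands above pass --nodes 1 --gpus 2 --GPU_ids '0,1'. Under the usual PyTorch DDP convention of one process per GPU, the total number of processes (world size) is nodes * gpus, and on a single node the number of IDs in --GPU_ids should equal --gpus. The check below is a hypothetical helper, not part of the repository, that catches a mismatch before launching a long run.

```python
from typing import List


def parse_gpu_ids(gpu_ids: str) -> List[int]:
    """Turn a comma-separated string such as '0,1' into [0, 1]."""
    return [int(tok) for tok in gpu_ids.split(",") if tok.strip()]


def world_size(nodes: int, gpus: int, gpu_ids: str) -> int:
    """Return nodes * gpus after checking that --GPU_ids lists one distinct device per local process."""
    ids = parse_gpu_ids(gpu_ids)
    if len(ids) != gpus:
        raise ValueError(f"--gpus={gpus} but --GPU_ids lists {len(ids)} device(s): {ids}")
    if len(set(ids)) != len(ids):
        raise ValueError(f"duplicate GPU ids: {ids}")
    return nodes * gpus


print(world_size(nodes=1, gpus=2, gpu_ids="0,1"))  # 2
```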
More 1K Results
Acknowledgments
Cite
If you find our work helpful, please consider citing:
@inproceedings{dong2022incremental,
title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding},
author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}