High-Fidelity Pluralistic Image Completion with Transformers (ICCV 2021)

Overview

Image Completion Transformer (ICT)

Project Page | Paper (ArXiv) | Pre-trained Models | Supplemental Material

This repository is the official PyTorch implementation of our ICCV 2021 paper, High-Fidelity Pluralistic Image Completion with Transformers.

Ziyu Wan¹, Jingbo Zhang¹, Dongdong Chen², Jing Liao¹
¹City University of Hong Kong, ²Microsoft Cloud AI

🎈 Prerequisites

  • Python >=3.6
  • PyTorch >=1.6
  • NVIDIA GPU + CUDA cuDNN
pip install -r requirements.txt

To run inference directly, first download the pretrained models from Dropbox:

cd ICT
wget -O ckpts_ICT.zip "https://www.dropbox.com/s/cqjgcj0serkbdxd/ckpts_ICT.zip?dl=1"
unzip ckpts_ICT.zip

Some tips:

  • Masks should be binarized.
  • The extensions of images and masks should be .png.
  • The model is trained for 256x256 input resolution only.
  • Make sure that the downsampled (32x32 or 48x48) mask covers all the regions you want to fill. If it does not, dilate the mask.
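The mask tips above can be checked before running the model. The following is a minimal sketch, not the repository's own preprocessing: the helper names, the threshold of 127, and the nearest-neighbour downsampling scheme (sampling the first pixel of each low-resolution cell) are all assumptions made for illustration. It treats a value of 1 as a hole pixel.

```python
import numpy as np


def binarize(mask, threshold=127):
    """Return a 0/1 uint8 mask; values above `threshold` (assumed) are holes."""
    return (np.asarray(mask) > threshold).astype(np.uint8)


def lowres_mask(mask, size):
    """Nearest-neighbour downsampling to size x size (assumed scheme).

    Each low-resolution cell takes the value of the first pixel it spans.
    """
    h, w = mask.shape
    rows = (np.arange(size) * h + size - 1) // size  # ceil(j * h / size)
    cols = (np.arange(size) * w + size - 1) // size
    return mask[np.ix_(rows, cols)]


def covers_all_holes(mask, size):
    """True if every hole pixel lies in a cell that the low-res mask marks as hole."""
    h, w = mask.shape
    low = lowres_mask(mask, size)
    row_cell = np.arange(h) * size // h  # full-resolution row -> low-res row
    col_cell = np.arange(w) * size // w
    restored = low[np.ix_(row_cell, col_cell)]
    return bool(np.all(restored[mask == 1] == 1))


def dilate(mask, iterations=1):
    """Grow the hole region by one pixel per iteration (3x3 neighbourhood)."""
    out = mask.copy()
    h, w = out.shape
    for _ in range(iterations):
        padded = np.pad(out, 1)  # zero padding keeps the output size unchanged
        grown = np.zeros_like(out)
        for dy in range(3):
            for dx in range(3):
                grown |= padded[dy:dy + h, dx:dx + w]
        out = grown
    return out
```

A typical use would be to call `covers_all_holes(mask, 32)` (or 48) on a binarized 256x256 mask and, if it returns False, apply `dilate` with a growing number of iterations until the check passes.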

🌟 Pipeline

Why transformer?

Compared with traditional CNN-based methods, transformers have better capability in understanding shape and geometry.

🚀 Training

1) Transformer

cd Transformer
python main.py --name [exp_name] --ckpt_path [save_path] \
               --data_path [training_image_path] \
               --validation_path [validation_image_path] \
               --mask_path [mask_path] \
               --BERT --batch_size 64 --train_epoch 100 \
               --nodes 1 --gpus 8 --node_rank 0 \
               --n_layer [transformer_layer #] --n_embd [embedding_dimension] \
               --n_head [head #] --ImageNet --GELU_2 \
               --image_size [input_resolution]

Notes on transformer training:

  • --AMP: Reduces memory cost during training, but can sometimes lead to NaN losses.
  • --use_ImageFolder: Enable this option when training on ImageNet.
  • --random_stroke: Generates masks on the fly.
  • Our code also supports training on multiple machines.
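For multi-machine training, the --nodes, --gpus and --node_rank flags have to be consistent across machines. The sketch below shows the usual bookkeeping under the assumption that one process is started per GPU, which is the common convention for PyTorch distributed training; the function names are hypothetical and the repository's launcher may compute these values differently.

```python
def world_size(nodes, gpus_per_node):
    """Total number of training processes, assuming one process per GPU."""
    return nodes * gpus_per_node


def global_rank(node_rank, gpus_per_node, local_gpu):
    """Unique rank of the process driving GPU `local_gpu` on machine `node_rank`."""
    return node_rank * gpus_per_node + local_gpu


def ranks_on_node(node_rank, gpus_per_node):
    """Global ranks of all processes that machine `node_rank` would start."""
    return [global_rank(node_rank, gpus_per_node, g) for g in range(gpus_per_node)]
```

Under this convention, two machines with 8 GPUs each would both be started with --nodes 2 --gpus 8, with --node_rank 0 on the first and --node_rank 1 on the second.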

2) Guided Upsampling

cd Guided_Upsample
python train.py --model 2 --checkpoints [save_path] \
                --config_file ./config_list/config_template.yml \
                --Generator 4 --use_degradation_2

Notes on guided upsampling:

  • --use_degradation_2: Bilinear downsampling. Try to match the setting used for transformer training.
  • --prior_random_degree: Stochastically perturbs the sequence elements using their K nearest neighbours.
  • Modify the provided config template to suit your own training environment.
  • Training the upsampling network does not require many GPUs.
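One plausible reading of the --prior_random_degree note is that each element of the low-resolution token sequence is randomly swapped for one of its K nearest cluster centres in colour space, so the upsampler sees slightly noisy priors during training. The sketch below illustrates that idea only; `perturb_tokens`, its arguments, and the use of squared Euclidean distance are assumptions, not the repository's implementation.

```python
import numpy as np


def perturb_tokens(tokens, centers, k, rng):
    """Replace each token index with a random one among its k nearest centres.

    tokens  : integer array of cluster indices (any shape)
    centers : (C, D) array of cluster centres, e.g. RGB colours (assumed layout)
    k       : number of nearest neighbours to sample from; k=1 leaves the
              tokens unchanged when all centres are distinct
    rng     : numpy.random.Generator
    """
    # Pairwise squared distances between all cluster centres: shape (C, C).
    d2 = ((centers[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    # For each centre, the indices of its k closest centres (itself first).
    neighbours = np.argsort(d2, axis=1, kind="stable")[:, :k]
    # Pick one of the k neighbours uniformly at random for every token.
    choice = rng.integers(0, k, size=tokens.shape)
    return neighbours[tokens, choice]
```

A call such as `perturb_tokens(tokens, centers, 2, np.random.default_rng(0))` returns an array of the same shape as `tokens`, where each entry is either the original index or the index of its closest other centre.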

Inference

We provide a convenient, tidy script for inference.

python run.py --input_image [test_image_folder] \
              --input_mask [test_mask_folder] \
              --sample_num 1  --save_place [save_path] \
              --ImageNet --visualize_all

Notes on inference:

  • --sample_num: The number of completion results to generate.
  • --visualize_all: Disable this option to save each output result.
  • --ImageNet, --FFHQ, --Places2_Nature: You must enable one of these options to select the corresponding checkpoints.
  • Please use absolute paths.
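Before calling run.py, it can help to validate the two input folders. The sketch below is a hypothetical helper, not part of the repository: it resolves both folders to absolute paths, keeps only .png files, and checks that every image has a mask with the same file name. The matching-name check is an assumption based on a user report rather than a documented requirement.

```python
from pathlib import Path


def pair_inputs(image_dir, mask_dir):
    """Return sorted (image, mask) pairs of absolute .png paths.

    Raises ValueError if any image lacks a mask with the same file name,
    or the other way round.
    """
    image_dir = Path(image_dir).resolve()
    mask_dir = Path(mask_dir).resolve()
    images = {p.name: p for p in image_dir.glob("*.png")}
    masks = {p.name: p for p in mask_dir.glob("*.png")}
    unpaired = sorted(set(images) ^ set(masks))
    if unpaired:
        raise ValueError(f"unpaired files: {unpaired}")
    return [(images[name], masks[name]) for name in sorted(images)]
```

The resolved folder paths can then be passed to --input_image and --input_mask.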

More results

FFHQ

Places2

ImageNet

To Do

  • Release training code
  • Release testing code
  • Release pre-trained models
  • Add Google Colab

📔 Citation

If you find our work useful for your research, please consider citing the following papers :)

@article{wan2021high,
  title={High-Fidelity Pluralistic Image Completion with Transformers},
  author={Wan, Ziyu and Zhang, Jingbo and Chen, Dongdong and Liao, Jing},
  journal={arXiv preprint arXiv:2103.14031},
  year={2021}
}

The real-world application of image inpainting is also ready! Try and cite our old photo restoration algorithm here.

@inproceedings{wan2020bringing,
  title={Bringing Old Photos Back to Life},
  author={Wan, Ziyu and Zhang, Bo and Chen, Dongdong and Zhang, Pan and Chen, Dong and Liao, Jing and Wen, Fang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2747--2757},
  year={2020}
}

💡 Acknowledgments

This repo is built upon minGPT and Edge-Connect. We also thank OpenAI for the provided cluster centers.

📨 Contact

This repo is currently maintained by Ziyu Wan (@Raywzy) and is for academic research use only. Discussions and questions are welcome via [email protected].

Comments
  • Minimum recommended GPUs

    Hello, what do you think is the minimum recommended GPU specs (memory etc) for good performance, both for training on a new dataset and for testing the pluralistic completion? Thanks!

    opened by nickk124 7
  • Attention Mask

    Hi, thanks for your nice work. A few details are unclear to me, and I would appreciate your advice.

    1. BERT choice: I noticed that the --BERT option is used in all Transformer training and inference processes. In which situations should this option not be selected?
    2. Attention mask: As the paper describes, the transformer model captures the unmasked information to predict the probability distribution of the missing regions. But in the code (CausalSelfAttention), I found that the model attends to all positions, and the mask is not used except to occlude the input. How does it guarantee that attention is paid only to unmasked information?
    3. Auto-regressive: As far as I understand, the model generates all masked pixels end to end rather than auto-regressively. During inference, the model generates one pixel per iteration to improve sampling quality. If that is right, how do we guarantee the quality of the first masked position in each iteration during inference?

    Thanks again.

    opened by Janspiry 4
  • Why not end-to-end network?

    Thank you for proposing this good idea of using a Transformer as prior information! But why not train an end-to-end network? Is it because the results are not good?

    opened by Monalissaa 2
  • Obtaining completion probability maps (from paper)

    Hello,

    First of all, super cool model, and thanks for being so helpful with past questions. I was just wondering specifically how you generated pixel-wise completion probability maps as in Fig. 9 of your paper (I get how it's done in theory, I just wanted to see code if possible).

    Thanks!

    opened by nickk124 2
  • Pretrained Model link not working

    Hello, thank you for the great work. I tried to download the pretrained model but I get this error:

    --2021-11-23 18:50:20-- https://www.dropbox.com/s/cqjgcj0serkbdxd/ckpts_ICT.zip?dl=1
    Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
    Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
    HTTP request sent, awaiting response... 301 Moved Permanently
    Location: /s/dl/cqjgcj0serkbdxd/ckpts_ICT.zip [following]
    --2021-11-23 18:50:20-- https://www.dropbox.com/s/dl/cqjgcj0serkbdxd/ckpts_ICT.zip
    Reusing existing connection to www.dropbox.com:443.
    HTTP request sent, awaiting response... 404 Not Found
    2021-11-23 18:50:20 ERROR 404: Not Found.

    The link doesn't seem to work.

    opened by AH289 2
  • num_sample in sample_mask function

    Hi,

    I am confused by the sample_mask function in transformer/utils/util.py: it does not seem to use the num_sample argument and instead keeps num_sample=1. Is this intended?

    Moreover, you use top_k=40, but the paper uses top_k=50. Which is the better choice?

    Thanks,

    opened by samuro95 2
  • Transformer train loss leads to Nan

    Hi, @raywzy

    I am trying to train the model on ImageNet with the following setting: --data_path /ILSVRC2012/train --validation_path /ILSVRC2012/val --mask_path /PConv-Keras/irregular_mask/train_mask --batch_size 32 --train_epoch 100 --nodes 1 --gpus 8 --node_rank 0 --n_layer 35 --n_embd 1024 --n_head 8 --GELU_2 --image_size 32 --use_ImageFolder

    But I am getting NaN for the train and test loss (screenshot attached).

    This also happens with smaller datasets such as Paris StreetView (#train_images: 14900, #test_images: 100) (screenshot attached).

    Any suggestions on how to fix this issue?

    opened by hiyaroy12 2
  • Transformer training problem

    Hello, congrats on your nice work. I use a 16G GPU, but a single card can only run batch_size 3. I turned on mixed-precision training, and the other settings are --n_layer 35 --n_embd 512 --n_head 8, the same as your model trained on Places2. So I would like to know how you use 8 GPUs with batch_size 64 to train the transformer model.

    opened by DQiaole 2
  • How much RAM does inference need?

    I tried to get it working in Google Colab, but it seems that 25GB of RAM is not enough, and it seems to crash during the first step. How much RAM is expected?

    opened by styler00dollar 2
  • RuntimeError: Expected object of scalar type Long but got scalar type Float for argument

    Hi, thank you for sharing code.

    I want to run the training code on grayscale images, but I got the following error.

    # Mask is 12022, # Image is 12022
    # Mask is 12022, # Image is 0
    Warnning: There is no trained model found. An initialized model will be used.
    Warnning: There is no previous optimizer found. An initialized optimizer will be used.
    Resume from Epoch 0
    Traceback (most recent call last):
      File "main.py", line 139, in <module>
        mp.spawn(main_worker, nprocs=opts.gpus, args=(opts,))
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException:
    
    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
        fn(i, *args)
      File "/home/naoki/ICT/Transformer/main.py", line 73, in main_worker
        trainer.train(loaded_ckpt)
      File "/home/naoki/ICT/Transformer/DDP_trainer.py", line 203, in train
        run_epoch('train')
      File "/home/naoki/ICT/Transformer/DDP_trainer.py", line 139, in run_epoch
        logits, loss = model(x, y)
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/naoki/ICT/Transformer/models/model.py", line 254, in forward
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
      File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
        return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
    RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
    

    I checked the size and data type before the cross-entropy loss:

    logits: torch.Size([12, 1024, 512]), torch.float32
    targets: torch.Size([12, 1024]), torch.float32

    Could you tell me how to solve this problem? Thank you in advance.

    opened by naoki7090624 1
  • ImageDataset doesn't exist in datas.dataset (unused anyway)

    Otherwise I get this error:

    Traceback (most recent call last):
      File "inference.py", line 10, in <module>
        from datas.dataset import ImageDataset
    ImportError: cannot import name 'ImageDataset' from 'datas.dataset' (/content/ICT/Transformer/datas/dataset.py)
    
    opened by josephrocca 1
  • ProcessExitedException: process 0 terminated with signal SIGKILL

    Has anyone encountered this error: torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGKILL? I am looking for a solution.

    opened by CyrilCsy 0
  • Pretrained weight of upsampler of Places does not work well.

    Hi, @raywzy

    I tried to use both the FFHQ and Places2 pretrained weights of the upsampler. However, the pretrained upsampler weights for Places do not produce results of sufficient quality. We know that the generator weights were trained for 322000 iterations. Do you think that the attached results are correct?

    From left to right, 1st stage output | blended result with masked input image | raw output of upsampler | GT image | masked input image | raw output of upsampler within given mask


    opened by UdonDa 0
  • ### Something Wrong ###

    I am trying to evaluate your code on WSL2 under Windows 10 with an Nvidia RTX 3090.

    When I start the provided inference script (regardless of which checkpoints are used), I get this message at the very beginning:

    ### Something Wrong ###
      0%|  
    

    After that, the calculation continues. If "--FFHQ" or "--Places2_Nature" is specified, inference finishes with no error.

    However, if "--ImageNet" is specified, inference finishes with an error:

    raise AssertionError("Invalid device id")
    AssertionError: Invalid device id
    

    NVCC report:

    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2022 NVIDIA Corporation
    Built on Tue_May__3_19:00:59_Pacific_Daylight_Time_2022
    Cuda compilation tools, release 11.7, V11.7.64
    Build cuda_11.7.r11.7/compiler.31294372_0
    

    I tried changing line 44 in "run.py" from CUDA_VISIBLE_DEVICES=0,1 to CUDA_VISIBLE_DEVICES=0, with no luck.

    I found the reason for the message:

    ### Something Wrong ###
      0%|  
    

    That happened because the image and mask file names differed (they should be identical). However, the rest of the issue still exists. Besides that, after inference finishes (regardless of the specified checkpoints), the "output" subfolder in the "Guided_Upsample" folder is created but is empty. The "output" subfolder in the "Transformer" folder is created and contains generated images of 32 or 48 pixels.

    opened by semel1 0
  • Enable --random_stroke option but the mask path is still the default

    Thanks for the nice code! I tried to enable "--random_stroke" so the masks will be generated on the fly. But it seems that the "--mask_path" still needs input, otherwise it will point to the default path which is not on my PC.

    opened by zhenzey 0
  • About the missing parameter 'loader' in Guided Upsampling when inference is done

    Thank you for sharing. We would like to use your model as a baseline for comparison. However, we have run into a small problem with the inference code you provided, as shown in the attached screenshot. How can I solve it?

    opened by ScarletBlaze 2
Owner
Ziyu Wan
Ph.D Student @ City University of Hong Kong