ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral)

[Figure: ILVR + ADM]

This is the implementation of ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral).

This repository is heavily based on improved diffusion and guided diffusion. We use PyTorch-Resizer for the resizing function.

Overview

ILVR is a learning-free method for controlling the generation of unconditional DDPMs. ILVR refines each generation step with the low-frequency component of a perturbed reference image. Our method enables various tasks (image translation, paint-to-image, editing with scribbles) with only a single model trained on a target dataset.
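Conceptually, each reverse step keeps the high-frequency content of the unconditional proposal and takes the low-frequency content from the noised reference. Below is a minimal sketch of this refinement, with bicubic interpolation standing in for the repository's Resizer (all names here are illustrative):

import torch
import torch.nn.functional as F

def phi_N(x, N):
    # Low-pass filter phi_N: downsample by factor N, then upsample back.
    # Bicubic interpolation stands in for the repository's Resizer.
    h, w = x.shape[-2:]
    down = F.interpolate(x, size=(h // N, w // N), mode="bicubic", align_corners=False)
    return F.interpolate(down, size=(h, w), mode="bicubic", align_corners=False)

def ilvr_refine(proposal, y_prev, N):
    # One ILVR refinement: keep the high frequencies of the unconditional
    # proposal x'_{t-1}, replace its low frequencies with those of the
    # noised reference y_{t-1}.
    return proposal - phi_N(proposal, N) + phi_N(y_prev, N)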


Download pre-trained models

Create a folder models/ and download model checkpoints into it. Here are the unconditional models trained on FFHQ and AFHQ-dog:

These models have seen 10M and 4M images, respectively. You may also try models from guided diffusion.

ILVR Sampling

First, set the PYTHONPATH variable to point to the root of the repository:

export PYTHONPATH=$PYTHONPATH:$(pwd)

Then, place your input image into a folder ref_imgs/.

Run the ilvr_sample.py script. Specify the folder where you want to save the output in --save_dir.

Here, we provide flags for sampling from the models above. Feel free to change --down_N and --range_t to adjust the downsampling factor N and the conditioning range described in the paper.

Refer to improved diffusion for the --timestep_respacing flag.
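For intuition, respacing evaluates the model on an evenly strided subset of the trained timesteps, e.g. 1000 training steps respaced to 100 sampling steps. A simplified sketch (the actual space_timesteps in improved diffusion is more general, handling multiple sections and DDIM strides):

def respaced_timesteps(num_train_steps, num_sample_steps):
    # Evenly strided subset of the trained timesteps; a simplified
    # stand-in for space_timesteps() in improved diffusion.
    stride = num_train_steps / num_sample_steps
    return sorted({round(i * stride) for i in range(num_sample_steps)})

# respaced_timesteps(1000, 100) -> [0, 10, 20, ..., 990]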

python scripts/ilvr_sample.py  --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 100 --model_path models/ffhq_10m.pt --base_samples ref_imgs/face --down_N 32 --range_t 20 --save_dir output

ILVR sampling is implemented in p_sample_loop_progressive of guided-diffusion/gaussian_diffusion.py.
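In outline, the change amounts to a few extra lines after each unconditional reverse step. A hedged sketch (phi_N is the low-pass filter from the Overview sketch above; p_sample and q_sample are guided diffusion's reverse step and closed-form forward noising; other names are illustrative):

import torch

def ilvr_sample_loop(diffusion, model, shape, ref_img, N, range_t):
    img = torch.randn(*shape)
    for t in reversed(range(diffusion.num_timesteps)):
        t_batch = torch.tensor([t] * shape[0])
        out = diffusion.p_sample(model, img, t_batch)  # unconditional proposal x'_{t-1}
        img = out["sample"]
        if t > range_t:
            y = diffusion.q_sample(ref_img, t_batch)   # noised reference
            img = img - phi_N(img, N) + phi_N(y, N)    # swap in low frequencies
    return img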

Results

These are samples generated with N=8 and N=16:

[Figures: samples with N=8 and N=16]

These are cat-to-dog samples generated with N=32:

[Figure: cat-to-dog samples with N=32]

Note

This repo is a re-implementation of our method on top of guided diffusion. Our initial implementation of the paper was based on denoising-diffusion-pytorch.

Comments
  • Error in ilvr_sample after loading my trained model

    Hi,

    I trained a model from scratch using my own dataset. After training I ended up with checkpoint files like ema_0.9999_000010.pt, model000010.pt and opt000010.pt.

    Flags used for training:

    python train_model.py --data_dir /data1/ --image_size 256 --num_channels 128 --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule cosine --lr 1e-4 --batch_size 4 --save_dir /data2/
    

    I used the checkpoint file ema_0.9999_000010.pt for ILVR sampling, but it threw the following error.

    Flags used for sampling:

    python src/models/ILVR_GuidedDiffusion/ilvr_sample.py  --attention_resolutions 16 --class_cond False --diffusion_steps 4000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule cosine --num_channels 128 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 100 --model_path /data2/ema_0.9999_000010.pt --base_samples ref_imgs/bdd10k --down_N 32 --range_t 20 --save_dir reports/figures/guided
    

    Error

    Logging to reports/figures/guided
    creating model...
    Traceback (most recent call last):
      File "src/models/ILVR_GuidedDiffusion/ilvr_sample.py", line 134, in <module>
        main()
      File "src/models/ILVR_GuidedDiffusion/ilvr_sample.py", line 49, in main
        model.load_state_dict(
      File "/home/vinod/anaconda3/envs/lsgm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for UNetModel:
            Missing key(s) in state_dict: "input_blocks.2.0.in_layers.0.weight", "input_blocks.2.0.in_layers.0.bias", "input_blocks.2.0.in_layers.2.weight", "input_blocks.2.0.in_layers.2.bias", "input_blocks.2.0.emb_layers.1.weight", "input_blocks.2.0.emb_layers.1.bias", "input_blocks.2.0.out_layers.0.weight", "input_blocks.2.0.out_layers.0.bias", "input_blocks.2.0.out_layers.3.weight", "input_blocks.2.0.out_layers.3.bias", "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias", "output_blocks.1.1.in_layers.0.weight", "output_blocks.1.1.in_layers.0.bias", "output_blocks.1.1.in_layers.2.weight", "output_blocks.1.1.in_layers.2.bias", "output_blocks.1.1.emb_layers.1.weight", "output_blocks.1.1.emb_layers.1.bias", "output_blocks.1.1.out_layers.0.weight", "output_blocks.1.1.out_layers.0.bias", "output_blocks.1.1.out_layers.3.weight", "output_blocks.1.1.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.5.1.in_layers.0.weight", "output_blocks.5.1.in_layers.0.bias", "output_blocks.5.1.in_layers.2.weight", "output_blocks.5.1.in_layers.2.bias", "output_blocks.5.1.emb_layers.1.weight", "output_blocks.5.1.emb_layers.1.bias", "output_blocks.5.1.out_layers.0.weight", "output_blocks.5.1.out_layers.0.bias", "output_blocks.5.1.out_layers.3.weight", "output_blocks.5.1.out_layers.3.bias", "output_blocks.7.1.in_layers.0.weight", "output_blocks.7.1.in_layers.0.bias", "output_blocks.7.1.in_layers.2.weight", "output_blocks.7.1.in_layers.2.bias", "output_blocks.7.1.emb_layers.1.weight", "output_blocks.7.1.emb_layers.1.bias", "output_blocks.7.1.out_layers.0.weight", "output_blocks.7.1.out_layers.0.bias", "output_blocks.7.1.out_layers.3.weight", 
"output_blocks.7.1.out_layers.3.bias", "output_blocks.9.1.in_layers.0.weight", "output_blocks.9.1.in_layers.0.bias", "output_blocks.9.1.in_layers.2.weight", "output_blocks.9.1.in_layers.2.bias", "output_blocks.9.1.emb_layers.1.weight", "output_blocks.9.1.emb_layers.1.bias", "output_blocks.9.1.out_layers.0.weight", "output_blocks.9.1.out_layers.0.bias", "output_blocks.9.1.out_layers.3.weight", "output_blocks.9.1.out_layers.3.bias". 
            Unexpected key(s) in state_dict: "input_blocks.2.0.op.weight", "input_blocks.2.0.op.bias", "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.10.0.op.weight", "input_blocks.10.0.op.bias", "input_blocks.11.1.norm.weight", "input_blocks.11.1.norm.bias", "input_blocks.11.1.qkv.weight", "input_blocks.11.1.qkv.bias", "input_blocks.11.1.proj_out.weight", "input_blocks.11.1.proj_out.bias", "output_blocks.0.1.norm.weight", "output_blocks.0.1.norm.bias", "output_blocks.0.1.qkv.weight", "output_blocks.0.1.qkv.bias", "output_blocks.0.1.proj_out.weight", "output_blocks.0.1.proj_out.bias", "output_blocks.1.2.conv.weight", "output_blocks.1.2.conv.bias", "output_blocks.1.1.norm.weight", "output_blocks.1.1.norm.bias", "output_blocks.1.1.qkv.weight", "output_blocks.1.1.qkv.bias", "output_blocks.1.1.proj_out.weight", "output_blocks.1.1.proj_out.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.5.1.conv.weight", "output_blocks.5.1.conv.bias", "output_blocks.7.1.conv.weight", "output_blocks.7.1.conv.bias", "output_blocks.9.1.conv.weight", "output_blocks.9.1.conv.bias". 
            size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
            size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
    

    May I know how you generated the 256x256 FFHQ.pt and 256x256 AFHQ-dog.pt? I don't have any issues while loading those weights.

    opened by vinodrajendran001 4
  • Question about your paper's training

    Hi, thanks for sharing the code; I am very interested in your paper. I just wanted to ask a question regarding training: your training is the same as guided-diffusion, and only the sampling process is modified, right?

    opened by zhangzhili1112 3
  • Possible errors about using q_sample(x, t)

    Thanks for your impressive contribution! You implemented y_{t-1} by using the method q_sample(y, t). However, the right way seems to be q_sample(y, t-1). Is there an error here?
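    For reference, q_sample is the closed-form forward diffusion. A minimal sketch under standard DDPM notation (alphas_cumprod holds the cumulative products of alpha; names illustrative), which makes the t vs. t-1 indexing question concrete:

    import torch

    def q_sample(x_start, t, alphas_cumprod, noise=None):
        # Closed-form forward diffusion:
        # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        if noise is None:
            noise = torch.randn_like(x_start)
        abar_t = alphas_cumprod[t].view(-1, *([1] * (x_start.dim() - 1)))
        return abar_t.sqrt() * x_start + (1.0 - abar_t).sqrt() * noise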

    opened by ZGCTroy 2
  • ilvr_sample hangs at the line "model_kwargs = next(data)"

    Hi,

    I am using the FFHQ pt file to load the checkpoint and do the sampling.

    I just placed one example image in the ref_imgs/face/ folder and ran ilvr_sample.py as per the README, but the script hangs at the line model_kwargs = next(data).

    May I know how to resolve it?

    Thanks.
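    For context, data here is a generator over the --base_samples folder that must yield model_kwargs batches indefinitely. A hypothetical sketch of such a loader (the repository's actual load_reference may differ):

    import os
    import numpy as np
    import torch
    from PIL import Image

    def load_reference(base_dir, batch_size, image_size):
        # Hypothetical loader: yields {"ref_img": batch} dicts forever.
        paths = sorted(
            os.path.join(base_dir, f)
            for f in os.listdir(base_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        assert paths, f"no reference images found in {base_dir}"
        while True:
            for i in range(0, len(paths), batch_size):
                batch = []
                for p in paths[i : i + batch_size]:
                    img = Image.open(p).convert("RGB").resize((image_size, image_size))
                    arr = np.asarray(img, dtype=np.float32) / 127.5 - 1.0
                    batch.append(torch.from_numpy(arr).permute(2, 0, 1))
                yield {"ref_img": torch.stack(batch)}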

    opened by vinodrajendran001 2
  • Training Requirement for GPU

    Excellent work, and thanks for sharing your code. I'm a novice at diffusion models, and I'm concerned about the GPU resources this kind of model needs. I wonder whether all the training can be done on a single RTX 3090.

    opened by KevinGoodman 2
  • Dataset used for training ffhq_10m.pt

    Thanks for your amazing work and nice sharing. You've said that 10M images were used to train ffhq_10m.pt in your README. As far as I know, there are 70,000 images in the FFHQ dataset, so this is a little confusing. Is there any other data I've missed?

    opened by sunyasheng 2
  • Can this model be applied to other data modalities?

    Hi, authors. Can this model be applied to other data modalities, such as audio or text? Have you tried it? I hope you can give me some suggestions. Thanks in advance!

    opened by ZDstandup 1
  • Questions about style transfer

    Hi teams!

    Thanks for sharing your code and excellent work! I have tried your sampling code on my GPUs and got great outputs. I'm wondering whether your model can produce outputs that keep the texture and shape of the input but are colorized in the style of the reference picture. I have tried a large N (e.g., N=64), which, as you mention in the paper, preserves only the coarse aspect (e.g., color scheme) of the reference, but didn't get good results. Could you suggest any ideas to solve this problem?

    (What I have done so far: I used image_train.py and domain-A datasets (256x256) to train the model ema_0.9999_010000.pt, and then used some 256x256 images from domain B with a large N to generate outputs.)

    Thanks in advance!

    opened by SkrBully 1
  • Confusion about timesteps

    Thanks for the fantastic work. In the paper, you say that the unconditional DDPM trained with the publicly available PyTorch implementation (https://github.com/rosinality/denoising-diffusion-pytorch) has 1000 timesteps during the training stage, but during the inference stage you directly use 100 steps without respacing. How does that work? Is there a problem with my understanding? I found no relevant explanation in the paper.

    opened by zacharyclam 1
  • DDIM sampling

    Thanks for the awesome work. Can I use ILVR sampling based on the DDIM sampling scheme? I tried to use the DDIM sampling scheme with ILVR, but the result is very blurry.
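    For concreteness, one way to graft ILVR's low-pass replacement onto an eta=0 DDIM step might look like the following untested sketch (phi_N is a down/upsampling low-pass filter; all names are illustrative; whether the stochastic noising of the reference fits the deterministic DDIM trajectory is exactly what this issue is asking):

    import torch

    def ddim_ilvr_step(model, x_t, ref_img, t, alphas_cumprod, N, phi_N):
        # One eta=0 DDIM step, then ILVR's low-frequency replacement.
        abar_t = alphas_cumprod[t]
        abar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
        eps = model(x_t, torch.tensor([t]))  # predicted noise
        x0_pred = (x_t - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
        x_prev = abar_prev.sqrt() * x0_pred + (1 - abar_prev).sqrt() * eps
        # Noise the reference to the same step, then swap low frequencies.
        y_prev = abar_prev.sqrt() * ref_img + (1 - abar_prev).sqrt() * torch.randn_like(ref_img)
        return x_prev - phi_N(x_prev, N) + phi_N(y_prev, N)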

    opened by zacharyclam 1
  • How to sample results on LSUN dataset?

    Hi, thanks for your great work on diffusion models! I am wondering whether I can download a pre-trained LSUN model from improved-diffusion and sample with it in your code. I find there are some mismatches when loading the parameters. Have you modified the code base? Thank you.

    opened by XinYu-Andy 1
  • Checkpoints no longer available

    Hi,

    I am trying to download the checkpoints of your model, but the link is no longer available; it says that the file does not exist. I attach the message from Google Drive below.

    Could you provide a new link or the checkpoints?

    [Screenshot of the Google Drive error message]

    Thanks in advance.

    opened by eleGAN23 2
  • Adaptation of pre-trained model

    Hi! Thank you for your interesting repo!

    I have a question about training using a pre-trained model from guided diffusion.

    Is it OK to put another pre-trained model in the ./models directory?

    opened by mikio303 0
  • FID test settings

    It's mentioned that 50K real images are used for the FID scores in Table 1. However, there are 70K images in FFHQ and only 1K in METFACES. Could you clarify how the 50K are sampled and/or duplicated? Thanks.

    opened by zhihongp 1
Owner
Jooyoung Choi
Deep Generative Models
Pytorch-diffusion - A basic PyTorch implementation of 'Denoising Diffusion Probabilistic Models'

PyTorch implementation of 'Denoising Diffusion Probabilistic Models' This reposi

Arthur Juliani 76 Jan 7, 2023
Denoising Diffusion Probabilistic Models

Denoising Diffusion Probabilistic Models This repo contains code for DDPM training. Based on Denoising Diffusion Probabilistic Models, Improved Denois

Alexander Markov 7 Dec 15, 2022
Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in Pytorch

Retrieval-Augmented Denoising Diffusion Probabilistic Models (wip) Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in P

Phil Wang 55 Jan 1, 2023
A denoising diffusion probabilistic model (DDPM) tailored for conditional generation of protein distograms

Denoising Diffusion Probabilistic Model for Proteins Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to gen

Phil Wang 108 Nov 23, 2022
BDDM: Bilateral Denoising Diffusion Models for Fast and High-Quality Speech Synthesis

Bilateral Denoising Diffusion Models (BDDMs) This is the official PyTorch implementation of the following paper: BDDM: BILATERAL DENOISING DIFFUSION M

null 172 Dec 23, 2022
A PyTorch implementation of the baseline method in Panoptic Narrative Grounding (ICCV 2021 Oral)

A PyTorch implementation of the baseline method in Panoptic Narrative Grounding (ICCV 2021 Oral)

Biomedical Computer Vision @ Uniandes 52 Dec 19, 2022
Official PyTorch implementation for FastDPM, a fast sampling algorithm for diffusion probabilistic models

Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models". FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. S

Zhifeng Kong 68 Dec 26, 2022
Collapse by Conditioning: Training Class-conditional GANs with Limited Data

Collapse by Conditioning: Training Class-conditional GANs with Limited Data Moha

Mohamad Shahbazi 33 Dec 6, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling @ INTERSPEECH 2021 Accepted

NU-Wave — Official PyTorch Implementation NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling Junhyeok Lee, Seungu Han @ MINDsLab Inc

MINDs Lab 242 Dec 23, 2022
PyTorch Implementation of DiffGAN-TTS: High-Fidelity and Efficient Text-to-Speech with Denoising Diffusion GANs

DiffGAN-TTS - PyTorch Implementation PyTorch implementation of DiffGAN-TTS: High

Keon Lee 157 Jan 1, 2023
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling For Official repo of NU-Wave: A Diffusion Probabilistic Model for Neural Audio Up

Rishikesh (ऋषिकेश) 38 Oct 11, 2022
Pytorch implementation of "Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech"

GradTTS Unofficial Pytorch implementation of "Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech" (arxiv) About this repo This is an unoffic

HeyangXue1997 103 Dec 23, 2022
(ICCV 2021) ProHMR - Probabilistic Modeling for Human Mesh Recovery

ProHMR - Probabilistic Modeling for Human Mesh Recovery Code repository for the paper: Probabilistic Modeling for Human Mesh Recovery Nikos Kolotouros

Nikos Kolotouros 209 Dec 13, 2022
[CVPR 2022 Oral] EPro-PnP: Generalized End-to-End Probabilistic Perspective-n-Points for Monocular Object Pose Estimation

EPro-PnP EPro-PnP: Generalized End-to-End Probabilistic Perspective-n-Points for Monocular Object Pose Estimation In CVPR 2022 (Oral). [paper] Hanshen

同济大学智能汽车研究所综合感知研究组 (Comprehensive Perception Research Group under Institute of Intelligent Vehicles, School of Automotive Studies, Tongji University) 842 Jan 4, 2023
The undersampled DWI image using Slice-Interleaved Diffusion Encoding (SIDE) method can be reconstructed by the UNet network.

UNet-SIDE The undersampled DWI image using Slice-Interleaved Diffusion Encoding (SIDE) method can be reconstructed by the UNet network. For Super Reso

TIANTIAN XU 1 Jan 13, 2022
Code for "Human Pose Regression with Residual Log-likelihood Estimation", ICCV 2021 Oral

Human Pose Regression with Residual Log-likelihood Estimation [Paper] [arXiv] [Project Page] Human Pose Regression with Residual Log-likelihood Estima

JeffLi 347 Dec 24, 2022
Improving Contrastive Learning by Visualizing Feature Transformation, ICCV 2021 Oral

Improving Contrastive Learning by Visualizing Feature Transformation This project hosts the codes, models and visualization tools for the paper: Impro

Bingchen Zhao 83 Dec 15, 2022
BARF: Bundle-Adjusting Neural Radiance Fields 🤮 (ICCV 2021 oral)

BARF 🤮: Bundle-Adjusting Neural Radiance Fields Chen-Hsuan Lin, Wei-Chiu Ma, Antonio Torralba, and Simon Lucey IEEE International Conference on Comp

Chen-Hsuan Lin 539 Dec 28, 2022