Release for Improved Denoising Diffusion Probabilistic Models

Overview

improved-diffusion

This is the codebase for Improved Denoising Diffusion Probabilistic Models (Nichol & Dhariwal, 2021, arXiv:2102.09672).

Usage

This section of the README walks through how to train and sample from a model.

Installation

Clone this repository and navigate to it in your terminal. Then run:

pip install -e .

This should install the improved_diffusion python package that the scripts depend on.

Preparing Data

The training code reads images from a directory of image files. In the datasets folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names).

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass --data_dir path/to/images to the training script, and it will take care of the rest.
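
For example, the class name for a class-conditional run is everything before the first underscore in a file's basename. A minimal sketch of that convention (the helper name here is hypothetical; the actual loader lives in improved_diffusion/image_datasets.py):

import os

def label_from_filename(path):
    # "path/to/mylabel1_XXX.jpg" -> "mylabel1"; directory names are
    # ignored, only the underscore prefix of the basename is the label
    return os.path.basename(path).split("_")[0]

print(label_from_filename("data/cats/mylabel1_XXX.jpg"))  # mylabel1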

Training

To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"

Here are some changes we experiment with, and how to set them in the flags:

  • Learned sigmas: add --learn_sigma True to MODEL_FLAGS
  • Cosine schedule: change --noise_schedule linear to --noise_schedule cosine
  • Reweighted VLB: add --use_kl True to DIFFUSION_FLAGS and add --schedule_sampler loss-second-moment to TRAIN_FLAGS.
  • Class-conditional: add --class_cond True to MODEL_FLAGS.
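
For example, to train a class-conditional model with learned sigmas and the cosine schedule, the baseline flags above become:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True --class_cond True"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine"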

Once you have set up your hyperparameters, you can run an experiment like so:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

You may also want to train in a distributed manner. In this case, run the same command with mpiexec:

mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

When training in a distributed manner, you must manually divide the --batch_size argument by the number of ranks. In lieu of distributed training, you may use --microbatch 16 (or --microbatch 1 in extreme memory-limited cases) to reduce memory usage.
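
For example, to keep the effective batch size at 128 across 4 GPUs, give each rank a quarter of it:

TRAIN_FLAGS="--lr 1e-4 --batch_size 32"  # 128 total / 4 ranks
mpiexec -n 4 python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS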

The logs and saved models will be written to a logging directory determined by the OPENAI_LOGDIR environment variable. If it is not set, then a temporary directory will be created in /tmp.
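
For example, to send logs and checkpoints to a directory of your choosing:

export OPENAI_LOGDIR=/path/to/logs
python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS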

Sampling

The above training script saves checkpoints to .pt files in the logging directory. These checkpoints will have names like ema_0.9999_200000.pt and model200000.pt. You will likely want to sample from the EMA models, since those produce much better samples.

Once you have a path to your model, you can generate a large batch of samples like so:

python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

Again, this will save results to a logging directory. Samples are saved as a large npz file, where arr_0 in the file is a large batch of samples.
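
A minimal sketch for inspecting the saved samples (the filename here is illustrative; the actual name encodes the batch shape):

import numpy as np

samples = np.load("samples.npz")["arr_0"]  # uint8 array of shape (N, H, W, C)
print(samples.shape, samples.dtype)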

Just like for training, you can run image_sample.py through MPI to use multiple GPUs and machines.

You can change the number of sampling steps using the --timestep_respacing argument. For example, --timestep_respacing 250 uses 250 steps to sample. Passing --timestep_respacing ddim250 is similar, but uses the uniform stride from the DDIM paper rather than our stride.

To sample using DDIM, pass --use_ddim True.
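
Putting these together, a 250-step DDIM sampling run looks like:

python scripts/image_sample.py --model_path /path/to/model.pt --use_ddim True --timestep_respacing ddim250 $MODEL_FLAGS $DIFFUSION_FLAGS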

Comments
  • FID is unstable during training on CIFAR-10

    Hi, I am training on CIFAR-10 and testing the FID (50k samples with 100 timesteps) after training for a certain number of iterations. However, the FID is unstable in my experiments. Specifically, it is 12.00, 7.67, 16.88, 7.78, 6.65, 18.43, and 11.51 at 80k, 90k, 100k, 150k, 200k, 250k, and 300k iterations, respectively (i.e., an up-and-down pattern). Has anyone encountered the same issue?

    opened by XinYu-Andy 5
  • Visualization of diffusion steps every n timesteps

    How can I save the output every n timesteps during diffusion sampling? So far I have edited lines of code in p_sample_loop_progressive(), with something like:

    # inside GaussianDiffusion.p_sample_loop_progressive(); assumes the file's
    # usual imports (os, numpy as np, torch as th, PIL.Image, logger) and that
    # `label` is defined in the surrounding code.
    # Note: `indices` runs from t = T-1 down to 0, so step 0 is pure noise.
    for step, i in enumerate(indices):
        t = th.tensor([i] * shape[0], device=device)
        with th.no_grad():
            out = self.p_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            yield out
            if step % 1000 == 0:
                print("step", step)
                # Save only the current batch. Accumulating every snapshot in
                # one list and then slicing arr[: shape[0]] always returns the
                # first (fully noisy) snapshot, which explains the symptom below.
                sample = ((out["sample"] + 1) * 127.5).clamp(0, 255).to(th.uint8)
                arr = sample.permute(0, 2, 3, 1).contiguous().cpu().numpy()
                out_path = os.path.join(logger.get_dir(), f"samples_{label}_{str(step).zfill(4)}.npz")
                out_image = os.path.join(logger.get_dir(), f"samples_{label}_{str(step).zfill(4)}.tif")
                Image.fromarray(arr[0]).save(out_image)
                np.savez(out_path, arr)
            img = out["sample"]

    but I always get a fully noisy image.

    Many thanks in advance

    opened by choROPeNt 2
  • Why do you need fp16 utils?

    Hi! What's the reason y'all are using explicit utilities to convert float32 to float16 tensors instead of using torch.amp? I'm thinking of building out my own implementation of this repository, using einops, and was curious if this is something I'd need to take into account instead of just using torch.amp.
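
    For context, the torch.amp alternative the question refers to looks roughly like this (a generic sketch, not code from this repo; model, optimizer, and loader are hypothetical):

        import torch

        scaler = torch.cuda.amp.GradScaler()
        for x in loader:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                loss = model(x).mean()
            scaler.scale(loss).backward()  # scaled to avoid fp16 underflow
            scaler.step(optimizer)
            scaler.update()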

    opened by vedantroy 2
  • Error when using a pretrained checkpoint

    When I use the checkpoint from the "Upsampling 256x256 model (280M parameters, trained for 500K iterations)" section, the program reports an error. Do you know why this happens? I used the recommended model parameters.

    RuntimeError: Error(s) in loading state_dict for UNetModel: Missing key(s) in state_dict: "input_blocks.4.0.skip_connection.weight", "input_blocks.4.0.skip_connection.bias", "input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias", "input_blocks.8.1.norm.weight", "input_blocks.8.1.norm.bias", "input_blocks.8.1.qkv.weight", "input_blocks.8.1.qkv.bias", "input_blocks.8.1.proj_out.weight", "input_blocks.8.1.proj_out.bias", "input_blocks.10.0.skip_connection.weight", "input_blocks.10.0.skip_connection.bias", "input_blocks.10.1.norm.weight", "input_blocks.10.1.norm.bias", "input_blocks.10.1.qkv.weight", "input_blocks.10.1.qkv.bias", "input_blocks.10.1.proj_out.weight", "input_blocks.10.1.proj_out.bias", "input_blocks.11.1.norm.weight", "input_blocks.11.1.norm.bias", "input_blocks.11.1.qkv.weight", "input_blocks.11.1.qkv.bias", "input_blocks.11.1.proj_out.weight", "input_blocks.11.1.proj_out.bias". Unexpected key(s) in state_dict: "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.13.0.in_layers.0.weight", "input_blocks.13.0.in_layers.0.bias", "input_blocks.13.0.in_layers.2.weight", "input_blocks.13.0.in_layers.2.bias", "input_blocks.13.0.emb_layers.1.weight", "input_blocks.13.0.emb_layers.1.bias", "input_blocks.13.0.out_layers.0.weight", "input_blocks.13.0.out_layers.0.bias", "input_blocks.13.0.out_layers.3.weight", "input_blocks.13.0.out_layers.3.bias", "input_blocks.13.0.skip_connection.weight", "input_blocks.13.0.skip_connection.bias", "input_blocks.13.1.norm.weight", "input_blocks.13.1.norm.bias", "input_blocks.13.1.qkv.weight", "input_blocks.13.1.qkv.bias", "input_blocks.13.1.proj_out.weight", "input_blocks.13.1.proj_out.bias", "input_blocks.14.0.in_layers.0.weight", "input_blocks.14.0.in_layers.0.bias", "input_blocks.14.0.in_layers.2.weight", "input_blocks.14.0.in_layers.2.bias", "input_blocks.14.0.emb_layers.1.weight", "input_blocks.14.0.emb_layers.1.bias", "input_blocks.14.0.out_layers.0.weight", "input_blocks.14.0.out_layers.0.bias", "input_blocks.14.0.out_layers.3.weight", "input_blocks.14.0.out_layers.3.bias", "input_blocks.14.1.norm.weight", "input_blocks.14.1.norm.bias", "input_blocks.14.1.qkv.weight", "input_blocks.14.1.qkv.bias", "input_blocks.14.1.proj_out.weight", "input_blocks.14.1.proj_out.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "input_blocks.16.0.in_layers.0.weight", "input_blocks.16.0.in_layers.0.bias", "input_blocks.16.0.in_layers.2.weight", "input_blocks.16.0.in_layers.2.bias", "input_blocks.16.0.emb_layers.1.weight", "input_blocks.16.0.emb_layers.1.bias", "input_blocks.16.0.out_layers.0.weight", "input_blocks.16.0.out_layers.0.bias", "input_blocks.16.0.out_layers.3.weight",

    opened by ZY123-GOOD 2
  • Model checkpoints from Table 3

    Hi,

    Thanks for releasing the code, especially in PyTorch. I was wondering if you are also planning to release the checkpoints for models from Table 3?

    opened by VSehwag 2
  • Runtime error when trying to load a model for sampling

    Hello,

    I can't manage to sample from a model. When I try to do so, I get: RuntimeError: Error(s) in loading state_dict for UNetModel: Missing key(s) in state_dict: XXX size mismatch for: XXXX

    This happens when I try to load model00000.pt.

    Do you have an idea how to solve this issue?

    opened by NicolasNerr 1
  • Issue in the scaling term of dot product attention

    I noticed that in the dot product attention implementation here (the QKVAttention class) the scale is $1/\sqrt[4]{d}$ instead of $1/\sqrt{d}$. It is here in the codebase: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/unet.py#L247

    It is a minor issue and probably does not matter for the overall performance of the models, but I thought it was worth pointing out.
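
    (For what it's worth, the linked line applies the $1/\sqrt[4]{d}$ scale to q and k separately before the einsum, so their product still carries the usual $1/\sqrt{d}$ factor. A rough sketch of that pattern, with q, k, and ch as in that class:)

        import math
        import torch as th

        scale = 1 / math.sqrt(math.sqrt(ch))  # ch = per-head channel dim
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        # (q * d**-0.25) @ (k * d**-0.25) == (q @ k) * d**-0.5, but the
        # intermediate values stay smaller, which is friendlier to fp16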

    opened by saeidnp 1
  • Classifier guidance or classifier-free guidance?

    Hi, thanks for your great work!

    Recently I have been working on conditional generation with diffusion models, and I found that there is classifier guidance and classifier-free guidance. For the former, a classifier needs to be pre-trained, but I didn't find such a pre-trained classifier in your code. I am a little confused about whether you are using classifier-free guidance.

    opened by MMMMMz 1
  • broadcast_buffers - ignore

    After trying out the system on CPU I'm trying to set it up using the GPU and I get the following error:

    TypeError: __init__() got an unexpected keyword argument 'broadcast_buffers'

    I tried it on two different machines and got the same result. Has anyone come across this?

    opened by Shadeenu 1
  • What is the expected scale of the RESCALED_MSE loss?

    I’m working on recreating your results and I’m able to train with a bare MSE loss just fine. However, my implementation of the RESCALED_MSE loss (MSE for the means, VLB for the variances) gives a loss orders of magnitude larger than the bare MSE loss. This seems wrong, since you have an explicit 1/1000 scaling factor on the VLB loss, for the purpose of not overwhelming the MSE loss.

    What is the expected scale of RESCALED_MSE loss?
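
    (For reference: the hybrid objective in the paper is $L_{\text{hybrid}} = L_{\text{simple}} + \lambda L_{\text{vlb}}$ with $\lambda = 0.001$, and the VLB term is computed with a stop-gradient on the predicted mean so that it only trains the variances; under that weighting the combined loss should stay on roughly the same scale as the bare MSE term.)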

    opened by LucasSloan 1
  • Training Issue

    I tried to train a DDPM to produce cat images, but it seems like the training process does not work properly. Can anyone tell me what happened? No error is reported.

    I am using WSL2.

    Here is the script, train_ddpm.sh:

        #!/bin/bash
        MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
        DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine"
        TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
        python scripts/image_train.py --data_dir cats/ $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

    Console output:

        (Generator) root@LAPTOP-I6AJJ63E:/mnt/d/CatPicGenerateDDPM/improved-diffusion# ./train_ddpm.sh
        Logging to /tmp/openai-2022-07-20-21-25-58-111937
        creating model and diffusion...
        creating data loader...
        training...
        ./train_ddpm.sh: line 5: 455 Killed python scripts/image_train.py --data_dir cats/ $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
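
    ("Killed" with no Python traceback usually means the operating system's out-of-memory killer stopped the process. The README's --microbatch flag is a likely fix, e.g. TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --microbatch 16".)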

    Thank you very much!

    opened by hym97 1
  • Slow Sampling

    I have noticed that sampling with this model takes an extraordinary amount of time, far slower than any individual training step. At 4000 diffusion steps, training seems to get through around 100-500 iterations in the time it takes to generate a single image. Is this normal behavior, and if not, are there any fixes? Thank you for all of the help.

    opened by retepseamus 0
  • p_mean_variance mean calculation

    I was looking through the code to see how the paper was implemented, but I ran into an issue when looking at the part of the paper measuring the KL loss between two Gaussians. [equation screenshot omitted]

    Specifically, the loss at time t-1 is the KL loss between the predicted Gaussian and the real Gaussian at time t-1. The predicted Gaussian is defined as follows: [equation screenshots omitted]

    And the real Gaussian is defined as follows: [equation screenshot omitted]

    The formulation of the loss function makes sense to me, but when I look at the code, it looks like the authors are having the model predict mu_tilde (eq 11) as opposed to mu (eq 13). I'm looking at the following function in the code: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L232

    In this function, the mean is calculated from epsilon by first calculating the prediction for x_0, then calculating the mean at time t. [equation screenshot omitted]
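
    (For reference, assuming the equation numbers refer to the standard DDPM quantities, the relations under discussion are
    $$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t, \qquad \hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar\alpha_t}},$$
    and substituting $\hat{x}_0$ into $\tilde\mu_t$ recovers $\mu_\theta(x_t,t) = \frac{1}{\sqrt{\alpha_t}}\bigl(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t)\bigr)$, so predicting $x_0$ and plugging it into the posterior mean is algebraically the same parameterization.)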

    To predict x_0, the following function is used: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L328

    But this function looks to be the formulation for the mean function (eq 13).

    I have a couple of questions regarding the implementation:

    1. Why is the mean function (eq 11) for the real gaussian distribution (eq 12) being used when retrieving the value of the predicted gaussian distribution (eq 3) when the formulation for the predicted gaussian distribution is formulated as a function of eq 13?
    2. Why is x_0/x_start being calculated directly from eq 13, the predicted mean?

    Thanks for the help!

    opened by gmongaras 0
  • minimal number of samples for stable training -> some issue

    Hi !

    I was wondering if anyone has tested the diffusion training process with few images (1,000 or fewer) and obtained good results?

    I am working with rare pathology images, and I have only 300 of them. I am seeing some unexpected behavior in the generated samples (lack of diversity, color shifts, etc.). As far as I am aware, diffusion models work better than GANs with few training data points?

    Thank you

    opened by NicolasNerr 0
  • Strange color shifts for custom data (histopathology images)

    Hello,

    I am using the repo to produce synthetic histopathology images.

    However, I am observing a weird color shift during training. As training progresses, this shift gets worse and worse. Basically, the images I am working with are all kind of pinkish, but the produced samples are sometimes in this range and sometimes yellow, green, red, etc.

    Is this a problem that has been observed before with diffusion models? If so, is there a known solution or an explanation for this behavior?

    Thank you so much !

    opened by NicolasNerr 2
  • The meaning of the scaling factor 0.5 in sampling from p(x_{t-1}|x_t)

    I noticed a scaling factor of 0.5 when sampling from p(x_{t-1}|x_t). I tried to find the definition of this 0.5 in the paper but failed. Is there any special reason behind this 0.5? Would the performance be different without this factor?

    https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L386
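
    (For reference: the linked line computes the sample as mean + exp(0.5 * log_variance) * noise, and since $\exp(\tfrac{1}{2}\log\sigma^2) = \sigma$, the 0.5 simply converts the predicted log-variance into a standard deviation; it is not an extra tuning factor from the paper.)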

    opened by zihaozou 0
Owner: OpenAI