PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Overview

Score-Based Generative Modeling through Stochastic Differential Equations

This repo contains a PyTorch implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[schematic figure]
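
In the paper's notation, the forward SDE perturbs data according to

$$\mathrm{d}\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\mathbf{w},$$

and samples are generated by simulating the corresponding reverse-time SDE

$$\mathrm{d}\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, t) - g(t)^2\,\nabla_{\mathbf{x}} \log p_t(\mathbf{x})\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar{\mathbf{w}},$$

where the score $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$ is approximated by a time-dependent score model trained with score matching.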

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, latent code manipulation, and brings new conditional generation abilities (including but not limited to class-conditional generation, inpainting and colorization) to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images (samples below). In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

[FFHQ samples]

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.

JAX version

Please find a JAX implementation here, which additionally supports class-conditional generation with a pre-trained classifier and resuming an evaluation process after pre-emption.

JAX vs. PyTorch

In general, this PyTorch version consumes less memory but runs slower than JAX. Below is a benchmark for training an NCSN++ cont. model with the VE SDE; the hardware is 4x Nvidia Tesla V100 GPUs (32 GB each).

Framework                Time (seconds per step)   Total memory usage (GB)
PyTorch                  0.56                      20.6
JAX (n_jitted_steps=1)   0.30                      29.7
JAX (n_jitted_steps=5)   0.20                      74.8

How to run the code

Dependencies

Run the following to install a subset of the necessary Python packages for our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide the stats file for CIFAR-10. You can download cifar10_stats.npz and save it to assets/stats/. Check out #5 for how to compute this stats file for new datasets.
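
As a quick sanity check (the exact array names stored in the archive may differ), you can inspect the downloaded stats file with NumPy:

import numpy as np

stats = np.load('assets/stats/cifar10_stats.npz')
print(list(stats.keys()))  # shows which precomputed statistics are stored for FID/KID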

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the path of a config file is a combination of the following dimensions:

    • dataset: One of cifar10, celeba, celebahq, celebahq_256, ffhq_256, ffhq.
    • model: One of ncsn, ncsnv2, ncsnpp, ddpm, ddpmpp.
    • continuous: train the model with continuously sampled time steps.
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts training a new model, or resumes training of an existing model if its meta-checkpoints (used for resuming runs after pre-emption in a cloud environment) exist in workdir/checkpoints-meta. When set to "eval", it can perform an arbitrary combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute the Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in assets/stats.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
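
    For example, hypothetical invocations might look like the following (the config path is illustrative; use any config file that exists under configs/):

    # Train a model (illustrative config path)
    python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --mode train --workdir ./exp/vp_cifar10_ddpmpp

    # Evaluate: generate samples for FID/IS and compute bits/dim on the test set
    python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --mode eval --workdir ./exp/vp_cifar10_ddpmpp --eval_folder eval --config.eval.enable_sampling --config.eval.enable_bpd --config.eval.dataset=test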

How to extend the code

  • New SDEs: inherit the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional, and the default is the Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be directly used in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py (see the sketch below).
  • New correctors: inherit the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be directly used in sampling.get_pc_sampler and in all other controllable generation methods in controllable_generation.py.
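
As a concrete illustration, here is a minimal predictor sketch that mirrors the Euler-Maruyama predictor (it assumes the Predictor base class and the register_predictor decorator from sampling.py; treat the exact signatures as approximate):

import numpy as np
import torch
from sampling import Predictor, register_predictor

@register_predictor(name='my_euler_maruyama')
class MyEulerMaruyamaPredictor(Predictor):
  """One Euler-Maruyama step of the reverse-time SDE."""

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, x, t):
    dt = -1. / self.rsde.N                  # negative step size: integrate backwards in time
    z = torch.randn_like(x)                 # fresh Gaussian noise for the diffusion term
    drift, diffusion = self.rsde.sde(x, t)  # reverse-time drift and diffusion coefficients
    x_mean = x + drift * dt
    x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
    return x, x_mean

Once registered, the name (here 'my_euler_maruyama') should be selectable through config.sampling.predictor.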

Pretrained checkpoints

All checkpoints are provided in this Google Drive.

Instructions: You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in Table 3 of our paper (corresponding to the FID and IS columns in the table below). The second checkpoint (with a larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2 of our paper (corresponding to the FID (ODE) and NLL (bits/dim) columns in the table below). The former has the smallest FID over the course of training (evaluated every 50k iterations); the latter is the last checkpoint of training.

Per Google's policy, we cannot release our original CelebA and CelebA-HQ checkpoints. That said, I have re-trained models on FFHQ 1024px, FFHQ 256px and CelebA-HQ 256px with personal resources, and they achieved similar performance to our internal checkpoints.

Here is a detailed list of checkpoints and their results reported in the paper. FID (ODE) corresponds to the sample quality of the black-box ODE solver applied to the probability flow ODE.

Checkpoint path                        FID    IS     FID (ODE)   NLL (bits/dim)
ve/cifar10_ncsnpp/                     2.45   9.73   -           -
ve/cifar10_ncsnpp_continuous/          2.38   9.83   -           -
ve/cifar10_ncsnpp_deep_continuous/     2.20   9.89   -           -
vp/cifar10_ddpm/                       3.24   -      3.37        3.28
vp/cifar10_ddpm_continuous             -      -      3.69        3.21
vp/cifar10_ddpmpp                      2.78   9.64   -           -
vp/cifar10_ddpmpp_continuous           2.55   9.58   3.93        3.16
vp/cifar10_ddpmpp_deep_continuous      2.41   9.68   3.08        3.13
subvp/cifar10_ddpm_continuous          -      -      3.56        3.05
subvp/cifar10_ddpmpp_continuous        2.61   9.56   3.16        3.02
subvp/cifar10_ddpmpp_deep_continuous   2.41   9.57   2.92        2.99

Checkpoint path                        Samples
ve/bedroom_ncsnpp_continuous           [bedroom samples]
ve/church_ncsnpp_continuous            [church samples]
ve/ffhq_1024_ncsnpp_continuous         [FFHQ 1024px samples]
ve/ffhq_256_ncsnpp_continuous          [FFHQ 256px samples]
ve/celebahq_256_ncsnpp_continuous      [CelebA-HQ 256px samples]

Demonstrations and tutorials

Link Description
Open In Colab Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (JAX + FLAX)
Open In Colab Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (PyTorch)
Open In Colab Tutorial of score-based generative models in JAX + FLAX
Open In Colab Tutorial of score-based generative models in PyTorch

Tips

  • When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps. For CIFAR-10, we recommend using config.training.n_jitted_steps=5 when your GPU/TPU has sufficient memory; otherwise we recommend using config.training.n_jitted_steps=1. Our current implementation requires config.training.log_freq to be divisible by n_jitted_steps for logging and checkpointing to work normally.
  • The snr (signal-to-noise ratio) parameter of LangevinCorrector somewhat behaves like a temperature parameter. Larger snr typically results in smoother samples, while smaller snr gives more diverse but lower-quality samples. Typical values of snr are 0.05 - 0.2, and it requires tuning to strike the sweet spot.
  • For VE SDEs, we recommend choosing config.model.sigma_max to be the maximum pairwise distance between data samples in the training dataset.
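
For the last tip, a rough sketch (the helper name is hypothetical) of estimating that maximum pairwise distance from a subsample of the training data:

import torch

def estimate_sigma_max(data: torch.Tensor, chunk: int = 1024) -> float:
  """Approximate config.model.sigma_max as the maximum pairwise Euclidean distance.

  data: an (N, ...) tensor holding (a subsample of) the training examples.
  """
  flat = data.reshape(data.shape[0], -1).float()
  max_dist = 0.0
  for i in range(0, flat.shape[0], chunk):
    # pairwise distances between one chunk of examples and the full set
    d = torch.cdist(flat[i:i + chunk], flat)
    max_dist = max(max_dist, d.max().item())
  return max_dist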

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
Comments
  • How to train a model with 16GB GPU

    Hey,

    thanks for your PyTorch implementation. I am trying to train a model on my custom dataset. I managed to set the dataset (tfrecords) up, but I run out of memory at step 0 of the training loop.

    RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 15.90 GiB total capacity; 14.61 GiB already allocated; 53.75 MiB free; 14.84 GiB reserved in total by PyTorch)
    

    Sadly, I do not have more GPU RAM options. My config is the following:

    from configs.default_lsun_configs import get_default_configs
    
    
    def get_config():
      config = get_default_configs()
      # training
      training = config.training
      training.sde = 'vesde'
      training.continuous = True
    
      # sampling
      sampling = config.sampling
      sampling.method = 'pc'
      sampling.predictor = 'reverse_diffusion'
      sampling.corrector = 'langevin'
    
      # data
      data = config.data
      data.dataset = 'CUSTOM'
      data.image_size = 128
      data.tfrecords_path = '/content/drive/MyDrive/Training/tf_dataset'
    
    
      # model
      model = config.model
      model.name = 'ncsnpp'
      model.sigma_max = 217
      model.scale_by_sigma = True
      model.ema_rate = 0.999
      model.normalization = 'GroupNorm'
      model.nonlinearity = 'swish'
      model.nf = 128
      model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
      model.num_res_blocks = 2
      model.attn_resolutions = (16,)
      model.resamp_with_conv = True
      model.conditional = True
      model.fir = True
      model.fir_kernel = [1, 3, 3, 1]
      model.skip_rescale = True
      model.resblock_type = 'biggan'
      model.progressive = 'output_skip'
      model.progressive_input = 'input_skip'
      model.progressive_combine = 'sum'
      model.attention_type = 'ddpm'
      model.init_scale = 0.
      model.fourier_scale = 16
      model.conv_size = 3
    
      return config
    

    Are there any options to improve memory efficiency? I would like to stay at a 128x128 resolution (if it is possible).

    Thanks!

    opened by pbizimis 7
  • Accelerate the Training Process

    Hi Yang Song, thanks for your nice work.

    I tried to reproduce the experiment "configs/subvp/cifar10_ncsnpp_continuous.py", which runs on a single V100 with a batch size of 128 images. However, I found the training to be very slow: as of now, 100K iterations have consumed around 23 hours.

    I want to ask whether an experiment with a larger batch size run on multiple GPUs can produce the same performance. At your convenience, would you share with me the config of the multi-GPU CIFAR-10 experiment?

    Sincere thanks for your help.

    opened by JiYuanFeng 2
  • Round operation for discrete models

    Hello,

    Firstly, congratulations on the amazing work. The ICLR award was well deserved!

    I don't want to be pedantic, but I realized that get_score_fn for discrete models doesn't have a torch.round() operation even though t at training time is an int. Therefore, sampling is done with slightly different values than training (e.g. 500.1 instead of 500). I'm not sure if this really affects performance; it's just an observation.

    I would add labels = torch.round(labels) after line 155 of the models/utils.py file.

    Many thanks, Pedro

    opened by SANCHES-Pedro 2
  • subVPSDE sample

    Hi,

    When I use Score_SDE_demo_PyTorch.ipynb and set the score-based model to subVPSDE, it raises "AttributeError: 'subVPSDE' object has no attribute 'alphas'" in the "PC sampling", "PC inpainting", and "PC colorizer" subsections.

    opened by TLi347 1
  • Latent Code Manipulation

    Hi, can someone tell me where the code for "manipulation of latent representations" is? For example, where did you use the interpolation and temperature-change operations? Thank you.

    opened by agSidharth 1
  • Add link to diffusers

    Hey :wave: from the diffusers team,

    Just wanted to ask if you are interested in adding a link to the diffusers library to your README. We're actively maintaining the ScoreSdeVE pipeline and are also planning to integrate your models with faster schedulers from Karras et al.

    When integrating the score_sde models we made sure to match the output scores 1-to-1 - do you think such an addition could be useful for the readers of your repo? :-)

    Thanks a mille for open-sourcing your method - it has been super helpful to better understand your paper and score-based diffusion models in general!

    opened by patrickvonplaten 0
  • Added sliced score matching; does not learn (DO NOT MERGE)

    Hello,

    I'm wondering if you have any idea why this does not work? I have updated get_sde_loss_fn to do sliced score matching as you've done in your sibling repo (https://github.com/ermongroup/sliced_score_matching). The eventual goal is to implement numerical sampling for an SDE without a closed form perturbation kernel.

    The best I can obtain are patterns that look like this: [image]

    This is on the CIFAR-10 dataset. I know 2000 iterations is substantially shorter than what is required to get "good" results, but when I use get_sde_loss_fn as it exists in your code, I get acceptable results:

    [image]

    Do you have any idea where I could be going wrong? Thanks in advance

    opened by KyleM4t1qbit 0
  • The small experiment in Figure 2

    Hello, I want to reproduce the conversion of one-dimensional data from a non-normal distribution into a standard normal distribution, as shown in Figure 2 of the paper. Could you give me some tips?

    opened by WN1695173791 0
  • Checkpoint

    There is a "for" circle in your code with relate to checkpoints, however I haven't discover any consecutively numbered checkpoints. Can I replace them with one checkpoint if it does not have much influence on outputs.

    Codes in run_lib: "for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1)" (There is no such consecutive numbered checkpoint in presented URL).

    Error for example: Waiting for check_point 9 (I can not download check_point 9 anywhere).

    opened by FT1021 0
  • ODE sampling gets unrealistic images

    Thanks for your great work which inspires so many!!

    As mentioned in #13, I also noticed that the pure ODE sampling results are not satisfying for 256x256 images (and perhaps larger sizes as well).

    Most of the time these generated images are blurry or over-smooth, sometimes even very noisy.

    To reproduce, one can simply run the demo notebook with pretrained checkpoints provided by the authors.

    opened by yuanzhi-zhu 1
  • Run the code on a single GPU

    Dear Song,

    Hi, thanks for your great work. I am trying to reproduce your work and enhance it a bit, but there is a problem in my setup. I have to use a single GPU due to resource limitations. For this, I add os.environ["CUDA_VISIBLE_DEVICES"]='1' in main.py or set the device with torch.device('cuda:1') in run_lib.train.

    However, it always assigns gpu:0, so RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR occurs.

    Is there any solution for it?

    The GPU I used is an NVIDIA RTX 3090 with 24 GB. Thank you.

    opened by shhh0620 0
  • Error in the denoising score matching loss?

    I may be wrong, but it appears to me that there is an error in the denoising score matching loss:

    https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/losses.py#L95

    In particular, std, the denominator for z, should be squared, shouldn't it? This is what is prescribed in Eq. 5 in the earlier work [1]. Does this seem to be the case?

    [1] Y. Song, S. Ermon, Generative Modeling by Estimating Gradients of the Data Distribution, 2019.

    opened by lihenryhfl 1
Owner: Yang Song, PhD Candidate in Stanford AI Lab