Official code for Score-Based Generative Modeling through Stochastic Differential Equations

Overview

Score-Based Generative Modeling through Stochastic Differential Equations

This repo contains the official implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Schematic: data is perturbed into noise by a forward-time SDE, and samples are generated by simulating the corresponding reverse-time SDE.]
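
To make the reverse-time idea concrete, here is a minimal, self-contained sketch (not the repo's sampling code) of Euler-Maruyama integration of the reverse-time SDE dx = [f(x, t) - g(t)^2 ∇_x log p_t(x)] dt + g(t) dw̄ for the VE SDE, where f = 0 and sigma(t) follows a geometric schedule; score_fn stands in for a trained score model:

import jax
import jax.numpy as jnp

def reverse_ve_sde_sample(rng, score_fn, shape, sigma_min=0.01, sigma_max=50.0, n_steps=1000):
  """Euler-Maruyama integration of the reverse VE SDE (illustrative only)."""
  rng, step_rng = jax.random.split(rng)
  x = jax.random.normal(step_rng, shape) * sigma_max          # start from the prior p_T
  dt = -1.0 / n_steps                                         # negative step: integrate from t=1 down to t~0
  for t in jnp.linspace(1.0, 1e-3, n_steps):
    sigma = sigma_min * (sigma_max / sigma_min) ** t          # sigma(t), the VE noise schedule
    g2 = 2.0 * sigma ** 2 * jnp.log(sigma_max / sigma_min)    # g(t)^2 = d[sigma^2(t)]/dt
    rng, step_rng = jax.random.split(rng)
    z = jax.random.normal(step_rng, shape)
    drift = -g2 * score_fn(x, t)                              # reverse-time drift (forward drift is 0 for the VE SDE)
    x = x + drift * dt + jnp.sqrt(g2) * jnp.sqrt(-dt) * z     # one Euler-Maruyama step
  return x

The samplers in sampling.py are more general (they work with any sde_lib.SDE and support predictor-corrector combinations), but the reverse-time update has this same form.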

Our framework enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, and latent code manipulation, and it brings new conditional generation abilities to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images. In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models all in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models, as well as evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.

How to run the code

Dependencies

Run the following to install a subset of the necessary Python packages for our code:

pip install -r requirements.txt

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming runs after pre-emption in a cloud environment) exist in workdir. When set to "eval", it can perform an arbitrary combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute their Inception score, FID, or KID.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
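
    For example, a typical training run followed by an evaluation run might look like the commands below (the config path is illustrative; substitute any file under configs/ and your own working directory):

    python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --workdir ./exp/cifar10_ddpmpp --mode train
    python main.py --config configs/vp/cifar10_ddpmpp_continuous.py --workdir ./exp/cifar10_ddpmpp --mode eval --eval_folder eval --config.eval.enable_sampling --config.eval.enable_bpd --config.eval.dataset=test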

How to extend the code

  • New SDEs: inherit the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional; the default is the Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be used directly in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py (a minimal sketch is given after this list).
  • New correctors: inherit the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be used directly in sampling.get_pc_sampler, and in all other controllable generation methods in controllable_generation.py.
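
For concreteness, here is a hedged sketch of a new predictor. The register_predictor decorator, the Predictor base class, and the (rng, x, t) signature of update_fn mirror sampling.py in the JAX codebase; the self.rsde attribute is assumed to be set up by the base class (as it is for the built-in predictors), and the broadcasting assumes NHWC image batches. Check the abstract class in your checkout for the authoritative interface.

import jax
import jax.numpy as jnp

from sampling import Predictor, register_predictor

@register_predictor(name='my_euler_maruyama')
class MyEulerMaruyamaPredictor(Predictor):
  """One reverse-time Euler-Maruyama step, usable as a drop-in predictor in get_pc_sampler."""

  def update_fn(self, rng, x, t):
    dt = -1.0 / self.rsde.N                   # negative step size: integrate backwards in time
    z = jax.random.normal(rng, x.shape)       # fresh Gaussian noise for the diffusion term
    drift, diffusion = self.rsde.sde(x, t)    # drift and diffusion of the reverse-time SDE
    x_mean = x + drift * dt
    x = x_mean + diffusion[:, None, None, None] * jnp.sqrt(-dt) * z
    return x, x_mean                          # noisy sample and its denoised mean

Once registered, the predictor can be selected by name through the sampling options in the config (e.g. config.sampling.predictor) in the same way as the built-in ones.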

Pretrained checkpoints

Link: https://drive.google.com/drive/folders/10pQygNzF7hOOLwP3q8GiNxSnFRpArUxQ?usp=sharing

You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in Table 3; it corresponds to the smallest FID over the course of training (evaluated every 50k iterations). The second checkpoint (with a larger number) is the one for which we reported likelihood values and the FIDs of black-box ODE samplers in Table 2; it is the last checkpoint saved during training.
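
If you want to restore one of these checkpoints outside of main.py, a minimal sketch with Flax's checkpoint utilities is shown below; init_state is a placeholder for a training-state pytree with the same structure as the one saved during training (the first Colab notebook below shows how such a state is constructed):

from flax.training import checkpoints

def load_pretrained(ckpt_dir, init_state, step=None):
  """Restore a saved training state from ckpt_dir.

  `step` selects a particular checkpoint number (e.g. the smaller or larger
  checkpoint discussed above); None restores the latest checkpoint found.
  """
  return checkpoints.restore_checkpoint(ckpt_dir, init_state, step=step)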

Demonstrations and tutorials

  • Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (Open in Colab)

  • Tutorial of score-based generative models in JAX + FLAX (Open in Colab)

  • Tutorial of score-based generative models in PyTorch (Open in Colab)

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
Comments
  • TypeError: can't multiply sequence by non-int of type 'BatchTracer'

    Hi Yang,

    I'm getting the following error:

    File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/losses.py", line 111, in loss_fn losses = jnp.square(batch_mul(score, std) + z) File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/utils.py", line 42, in batch_mul return jax.vmap(lambda a, b: a * b)(a, b) TypeError: can't multiply sequence by non-int of type 'BatchTracer'

    If I use chex.fake_pmap to be able to print inside the pmap, I see:

    std:

    Traced<ShapedArray(float32[8])>with<BatchTrace(level=1/0)> with val = DeviceArray([[ 0.21264698, 0.77755433, 0.27918625, 0.9448618 , 10.666621 , 0.24025024, 12.233008 , 3.626547 ]], dtype=float32) batch_dim = 0

    z:

    Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)> with val = DeviceArray([[[[[-2.01293421e+00, -2.17641640e+00, -1.23569024e+00], [ 6.13737464e-01, 1.50414258e-01, -2.59380966e-01], ....

    score:

    (Traced<ShapedArray(float32[8,32,32,3])>with<JVPTrace(level=3/0)> with primal = Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)> with val = DeviceArray([[[[[ 8.45328856e-08, -1.25030866e-07, -8.07002252e-08], ...

    I tried to match your versions of libraries as much as possible. Same jax, flax, jaxlib version. Tensorflow_gpu=2.4.1. I use one GPU and config="configs/ncsnpp/cifar10_continuous_ve.py".

    Update: If I do it on cifar10_continuous_vp.py it breaks here:

    File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/models/utils.py", line 197, in score_fn score = batch_mul(-model, 1. / std) TypeError: bad operand type for unary -: 'tuple'

    opened by AlexiaJM 7
  • Tips to train a model on a custom dataset

    Hey, I really like your work and I wanted to compare the results of an NCSNPP model to a GAN. I have a custom dataset of 31k images and I am using Google Colab (so 1x V100).

    Do you have any helpful comments on how I should train an NCSNPP model with that specification? Best case, the resolution would be between 64 and 256.

    Thanks!

    opened by pbizimis 5
  • Why does the approximate equality in Eq. (24) hold?

    Hi! While reading the proof that DDPM is a discretization of the VP-SDE in Appendix B of https://openreview.net/pdf?id=PxTIG12RRHS, I don't understand why Eq. (24) holds. I know that when $x\ll 1$, $\sqrt{1-x} \approx 1-x/2$. However, in Eq. (24), $\beta(t+\Delta t)\Delta t$ does not seem to satisfy this condition, because $\beta(t)=N\beta_i$, and when $\Delta t\rightarrow 0$, $N\rightarrow \infty$. Could you explain why this approximate equality still holds?

    opened by chenweize1998 2
  • which checkpoint for VP and deep-VP replicates the paper results

    Hi Yang,

    I just wanted to know which of the two checkpoints at https://drive.google.com/drive/folders/1F74y6G_AGqPw8DG5uhdO_Kf9DCX1jKfL and https://drive.google.com/drive/folders/1ikbUY_K4Rc2-lPz7baPxdEXtx76Xn5Ov replicates the results of the paper in Table 3.

    Alexia

    opened by AlexiaJM 2
  • about "JIT multiple training steps together"

    Hello, Dr. Song

    Thank you for sharing this excellent work. I saw that a parameter "n_jitted_steps" is used during training, and the code comment says: "JIT multiple training steps together for faster training." Can you explain why and how multiple training steps are JIT-compiled together? Does "n_jitted_steps" affect performance? That is, if I don't JIT multiple training steps together, will the performance be the same? Thank you in advance.

    opened by ShiZiqiang 2
  • Same consequences

    https://github.com/yang-song/score_sde/blob/0aef9f6421138be001e8766d1506d0eae76586c8/models/ddpm.py#L70

    The consequences below are the same whether this evaluates to True or False. Is that correct?

    opened by david-klindt 2
  • Checkpoint for CelebA-HQ

    Hi guys,

    first of all, thanks for the code! I wonder if you could upload a pretrained checkpoint for CelebA-HQ? How long does it take to train on CelebA-HQ, and how many GPUs are required?

    Thanks, Artsiom

    opened by asanakoy 2
  • Code for FID stats

    Could you share the code for calculating the FID stats of a dataset and saving it? It seems that it does not recognize the FID stats from the original FID authors in https://github.com/bioinf-jku/TTUR.

    opened by AlexiaJM 2
  • FID score of conditional sampling

    Hi,

    Thanks a lot for your amazing work. I have recently been reproducing your work on conditional sampling. I found that the FID scores of images sampled using conditional sampling (a VE score model with a wide ResNet classifier) are far higher than 2.20. Is there any suggestion for tuning the hyper-parameters to improve performance? Or can you provide the FID score for this experimental setting for reference? It seems that the paper only provides some visualization examples for this experiment.

    Thanks.

    opened by chen-hao-chao 1
  • forward pass for VP and VE

    Hi, could you please guide me on how to implement the forward pass, i.e., adding the noise schedule for the VP and VE methods? I found the codebase hard to read, so instead I am following the implementation of VP and VE in Karras et al. here: https://github.com/NVlabs/edm/blob/b2a26c921c5776cb52f7498248761d60649007a8/generate.py#L66

    Kindly guide me on how to implement the forward pass of adding noise.

    thanks

    opened by rabeeh-karimi 0
  • figure source code

    Dear Song,

    Many thanks for sharing this awesome idea and the source code! I am wondering whether you could point me to the source code for the following figure in Song 2021.

    Kind regards, Feng

    opened by shouldsee 0
  • [Huawei] 2012 Lab-Technical Exchange & Project Cooperation & Talented Youth Program Invitation

    Hi, Yang Song, I had the honor of reading your published paper and obtaining your contact information. I am Han Lu from the Huawei 2012 Lab. We are currently researching and exploring topics in audio, autonomous-driving perception, CG (such as character animation, rendering, 3D reconstruction, motion synthesis, and character interaction), CV (multi-modal learning algorithms), ML, NLP, and related areas, while also recruiting talent in these fields (internships, full-time positions, consultants, etc.). I look forward to an open and in-depth exchange with you. My contact information: Email: [email protected]; Tel: 17710876257; WeChat: 1274634225 (I am Xiaobao). Thanks. The 2012 Lab Central Media Technology Institute is Huawei's media technology innovation and engineering competence center. It is responsible for technical research, innovation, and breakthroughs in the company's mobile-phone camera, video, audio, and audio/video standards, ensuring that Huawei's media product technology continues to lead the industry. The Central Media Technology Institute has established R&D centers and professional laboratories in Japan, North America, Europe, and other overseas locations, as well as in Shenzhen, Hangzhou, Beijing, Shanghai, and other domestic cities. We hope to establish contact with you and look forward to your reply!

    opened by luhan-chen 0