Reference implementation for Deep Unsupervised Learning using Nonequilibrium Thermodynamics

Overview

Diffusion Probabilistic Models

This repository provides a reference implementation of the method described in the paper:

Deep Unsupervised Learning using Nonequilibrium Thermodynamics
Jascha Sohl-Dickstein, Eric A. Weiss, Niru Maheswaranathan, Surya Ganguli
International Conference on Machine Learning, 2015
http://arxiv.org/abs/1503.03585

This implementation builds a generative model of data by training a Gaussian diffusion process to transform a noise distribution into a data distribution in a fixed number of time steps. The mean and covariance of the diffusion process are parameterized using deep supervised learning. The resulting model is tractable to train, easy to exactly sample from, allows the probability of datapoints to be cheaply evaluated, and allows straightforward computation of conditional and posterior distributions.
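To make the idea concrete, here is a minimal NumPy sketch of the forward (noising) half of a Gaussian diffusion process, using the kernel q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I) from the paper. The function name, the stand-in data, and the linear beta schedule are assumptions made for illustration; they are not taken from this repository, which trains the reverse process with a neural network.

```python
import numpy as np

def forward_diffusion(x0, betas, rng):
    """Apply x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise for each beta_t."""
    x = np.asarray(x0, dtype=np.float64)
    trajectory = [x]
    for beta in betas:
        noise = rng.standard_normal(x.shape)
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * noise
        trajectory.append(x)
    # Shape: (len(betas) + 1, *x0.shape); index 0 is the original data
    return np.stack(trajectory)

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(1000, 4))  # stand-in "data", not MNIST
betas = np.linspace(1e-3, 0.2, 100)         # hypothetical schedule
traj = forward_diffusion(x0, betas, rng)
```

After enough steps the last entry of the trajectory is close to a unit Gaussian regardless of the starting data. The generative model is trained to run this process in reverse.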

Using the Software

To train a diffusion probabilistic model on the default dataset, MNIST, install the dependencies (see below) and then run python train.py.

Dependencies

  1. Install Blocks and its dependencies following these instructions.
  2. Set up Fuel and download MNIST following these instructions.

As of October 16, 2015, this code requires the bleeding-edge, rather than stable, versions of both Blocks and Fuel. (Thanks to David Hofmann for pointing out that the stable release will not work due to an interface change.)

Output

The objective function being minimized is the bound on the negative log likelihood in bits per pixel, minus the negative log likelihood under an identity-covariance Gaussian model. That is, it is the negative of the number in the rightmost column in Table 1 in the paper.
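The arithmetic behind this number can be sketched as follows. This is an illustration under a stated assumption, not the repository's code: the reference term is taken to be 0.5 * (1 + ln(2*pi)) nats per pixel, which is the expected negative log likelihood of unit-variance data under a unit Gaussian. The function name is hypothetical.

```python
import math

# Assumption: per-pixel reference cost of unit-variance data under N(0, 1), in nats.
GAUSS_NLL_NATS = 0.5 * (1.0 + math.log(2.0 * math.pi))

def relative_bits_per_pixel(bound_nats_per_pixel):
    """Subtract the Gaussian reference and convert from nats to bits."""
    return (bound_nats_per_pixel - GAUSS_NLL_NATS) / math.log(2.0)

# A model no better than the reference Gaussian scores 0.
baseline = relative_bits_per_pixel(GAUSS_NLL_NATS)
# A model whose bound is ln(2) nats lower scores -1 bit per pixel.
one_bit_better = relative_bits_per_pixel(GAUSS_NLL_NATS - math.log(2.0))
```

Under this convention, more negative values mean the model improves further on the Gaussian baseline.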

Logging information is printed to the console once per training epoch, including the current value of the objective on the training set.

Figures showing samples from the model, parameters, gradients, and training progress are also output periodically (every 25 epochs by default -- see train.py).

The samples from the model are of three types -- standard samples, samples inpainting the left half of masked images, and samples denoising images with Gaussian noise added (by default, the signal-to-noise ratio is 1). This demonstrates the straightforward way in which inpainting, denoising, and sampling from a posterior in general can be performed using this framework.
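The two corrupted inputs described above could be constructed roughly as in the sketch below. The helper names are hypothetical, and the signal-to-noise ratio is assumed here to be a ratio of variances; the repository may define it differently.

```python
import numpy as np

def mask_left_half(images):
    """Return a boolean mask that is True on the left half of each image."""
    mask = np.zeros(images.shape, dtype=bool)
    mask[..., : images.shape[-1] // 2] = True
    return mask

def add_noise_at_snr(images, snr, rng):
    """Add Gaussian noise with variance equal to the signal variance divided by snr."""
    noise_std = np.sqrt(images.var() / snr)
    return images + noise_std * rng.standard_normal(images.shape)

rng = np.random.default_rng(0)
images = rng.uniform(size=(8, 28, 28))        # stand-in batch, not MNIST
mask = mask_left_half(images)                 # pixels the model would inpaint
noisy = add_noise_at_snr(images, 1.0, rng)    # input for the denoising samples
```

With a signal-to-noise ratio of 1, the noise carries as much variance as the images themselves, so the noisy batch has roughly twice the variance of the original.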

Here are samples generated by this code after 825 training epochs on MNIST, trained using the command run train.py:

Here are samples generated by this code after 1700 training epochs on CIFAR-10, trained using the command run train.py --batch-size 200 --dataset CIFAR10 --model-args "n_hidden_dense_lower=1000,n_hidden_dense_lower_output=5,n_hidden_conv=100,n_layers_conv=6,n_layers_dense_lower=6,n_layers_dense_upper=4,n_hidden_dense_upper=100":

Miscellaneous

Different nonlinearities - In the paper, we used softplus units in the convolutional layers, and tanh units in the dense layers. In this implementation, I use leaky ReLU units everywhere.

Original source code - This repository is a refactoring of the code used to run the experiments in the published paper. In the spirit of reproducibility, if you email me a request I am willing to share the original source code. It is poorly commented and held together with duct tape though. For most applications, you will be better off using the reference implementation provided here.

Contact - I would love to hear from you. Let me know what goes right/wrong! [email protected]


Comments
  • How does `beta_arr` persist its computation graph across iterations?

    Hi, first of all thank you for making the code public. I am working on implementing it in PyTorch and have a question about the current Theano implementation.

    The method generate_beta_arr is only called once during initialization here. As far as I understand, the method defines parameters beta_perturb_coefficients which are then learnt during model training. Since these parameters are only used once (during initialization) to define beta_arr, the computation graph is only created once. The second time this model is run, the computation graph will not contain any information about beta_perturb_coefficients. Then how are these values learnt?

    E.g., I cannot do this in PyTorch without explicitly stating that I want to retain the computation graph:

    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.beta_arr = nn.Parameter(torch.randn(5))
            # Computed once here, so its graph is built only at construction time
            self.some_variable = self.beta_arr * self.beta_arr

        def forward(self, x):
            # Reuses the graph from __init__ on every call
            return self.some_variable * x
    
    opened by vinsis 1
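A note on the question above: Theano builds a symbolic graph once and compiles it, so an expression defined at initialization is re-evaluated from the current values of its shared variables on every call. PyTorch builds its graph dynamically, so the closest equivalent is to recompute the derived quantity inside the forward pass. The sketch below assumes this is the behavior wanted; the derived method is a hypothetical name.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.beta_arr = nn.Parameter(torch.randn(5))

    def derived(self):
        # Rebuilt on every call, so each forward pass gets a fresh graph
        return self.beta_arr * self.beta_arr

    def forward(self, x):
        return self.derived() * x

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
initial = model.beta_arr.detach().clone()
for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.ones(5)).sum()
    loss.backward()  # no retain_graph=True needed
    optimizer.step()
```

Because the product is recomputed each iteration, gradients reach beta_arr on every backward pass and the parameter is updated as usual.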
  • data_path issue

    Hi, I am trying to train the model on MNIST, but I get the following error after running python train.py:

    fuel.exceptions.ConfigurationError: Configuration not set and no default provided: data_path.

    I can't figure out how Fuel should be set up.

    Thanks in advance for your help.

    opened by nourgana 2
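A note on the error above: it indicates that Fuel has no data directory configured. As far as the Fuel documentation describes, the directory can be supplied through a data_path entry in ~/.fuelrc or through the FUEL_DATA_PATH environment variable. The sketch below takes the environment-variable route; the helper name and the directory are hypothetical, and the variable must be set before Fuel is imported.

```python
import os

def configure_fuel_data_path(path):
    """Point Fuel at a data directory via the FUEL_DATA_PATH environment variable."""
    path = os.path.abspath(os.path.expanduser(path))
    os.environ["FUEL_DATA_PATH"] = path
    return path

# Hypothetical location; use the directory that holds the converted MNIST file.
data_dir = configure_fuel_data_path("~/fuel_data")
```

The MNIST file still has to exist in that directory, which is what the Fuel download and conversion steps from the setup instructions produce.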
Owner
Jascha Sohl-Dickstein