Code for Multinomial Diffusion

Overview

Abstract

Generative flows and diffusion models have been predominantly trained on ordinal data, for example natural images. This paper introduces two extensions of flows and diffusion for categorical data such as language or image segmentation: Argmax Flows and Multinomial Diffusion. Argmax Flows are defined by a composition of a continuous distribution (such as a normalizing flow), and an argmax function. To optimize this model, we learn a probabilistic inverse for the argmax that lifts the categorical data to a continuous space. Multinomial Diffusion gradually adds categorical noise in a diffusion process, for which the generative denoising process is learned. We demonstrate that our method outperforms existing dequantization approaches on text modelling and modelling on image segmentation maps in log-likelihood.

Link: https://arxiv.org/abs/2102.05379

Instructions

In the folder containing setup.py, run

pip install --user -e .

The --user option ensures the library will only be installed for your user. The -e option makes it possible to modify the library, and modifications will be loaded on the fly.

You should now be able to use it.

Running experiments

Go to the experiment of interest (folder segmentation_diffusion or text_diffusion) and follow the readme instructions there.

Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.

Comments
  • Question about `q_pred_one_timestep`

    Thanks for the brilliant work! I'm having some trouble understanding the function q_pred_one_timestep, even though it is explained in the published paper.

    Some magic is happening in q_pred_one_timestep. Recall that at some point we need to compute $\mathcal C(x_t|(1 − \beta_t)x_{t-1}+ \beta_t/K)$ for different values of $x_t$, which when treated as a function outputs $(1 − \beta_t) + \beta_t/K$ if $x_t=x_{t-1}$ and $\beta_t/K$ otherwise. This function is symmetric, meaning that $\mathcal C(x_t|(1 − \beta_t)x_{t-1} + \beta_t/K) = \mathcal C(x_{t-1}|(1 − \beta_t)x_t + \beta_t/K)$. This is why we can switch the conditioning and immediately return the different probability vectors for $x_t$. This also corresponds to Equation 13.
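
    For concreteness, here is a minimal numeric sketch (my own, not code from the repository) of that symmetry: the $K \times K$ transition matrix $(1-\beta_t)I + \beta_t/K$ is symmetric, so reading off its column at a fixed $x_t$ gives the probability of reaching that $x_t$ from every possible $x_{t-1}$.

    import torch
    import torch.nn.functional as F

    K, beta = 5, 0.1
    # Transition matrix Q[i, j] = q(x_t = j | x_{t-1} = i) = (1 - beta) * I + beta / K
    Q = (1 - beta) * torch.eye(K) + beta / K
    assert torch.allclose(Q, Q.t())  # symmetric in x_t and x_{t-1}

    # Hence, for a one-hot x_t, the vector (1 - beta) * x_t + beta / K lists
    # q(x_t | x_{t-1} = k) for every class k, which is what the trick exploits.
    x_t = F.one_hot(torch.tensor(2), K).float()
    print((1 - beta) * x_t + beta / K)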

    My question is: although the equation $\mathcal C(x_t|(1 − \beta_t)x_{t-1} + \beta_t/K) = \mathcal C(x_{t-1}|(1 − \beta_t)x_t + \beta_t/K)$ holds, it holds for explicit $x_t$ and $x_{t-1}$. In other words, only if we have already determined the values of $x_t$ and $x_{t-1}$ can we write down such an equation. In the function q_pred_one_timestep, you pass $x_t$ into the function and compute $\mathcal C((1-\beta_t)x_t + \beta_t/K)$ as a whole distribution, which I think does not equal the distribution $\mathcal C((1-\beta_t)x_{t-1} + \beta_t/K)$. In other words, what confuses me is: if we want to get $q(x_t|x_{t-1})$, isn't it necessary to determine $x_{t-1}$ first? Since we have provided $x_t$ to this function, what is the point of calculating its distribution?

    Any reminder is greatly appreciated!

    opened by cantabile-kwok 13
  • Question about formula 13 in the paper and in the implementation

    The posterior $q(x_{t-1} \mid x_t, x_0)$ can be formulated as follows using Bayes' rule:

    $q(x_{t-1} \mid x_t, x_0) = \dfrac{q(x_t \mid x_{t-1})\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}$

    But in the paper, formula 13 seems to contain only $q(x_t \mid x_{t-1})$ and $q(x_{t-1} \mid x_0)$ in the numerator. Could you help me understand why the division by $q(x_t \mid x_0)$ is excluded?
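
    A small numeric sketch (mine, not the repository's code) of why the division can be omitted: $q(x_t \mid x_0)$ does not depend on $x_{t-1}$ and equals the sum of the numerator over $x_{t-1}$, so dividing by it is exactly a normalization over the $K$ categories.

    import torch

    K, beta, alpha_bar_prev = 4, 0.1, 0.7
    x_t = torch.tensor([1., 0., 0., 0.])   # one-hot x_t
    x_0 = torch.tensor([0., 0., 1., 0.])   # one-hot x_0

    # Both factors of the numerator, written as functions of x_{t-1} (indexed by class k):
    q_xt_given_xtmin1 = (1 - beta) * x_t + beta / K
    q_xtmin1_given_x0 = alpha_bar_prev * x_0 + (1 - alpha_bar_prev) / K
    numerator = q_xt_given_xtmin1 * q_xtmin1_given_x0

    # The denominator q(x_t | x_0) equals the sum of the numerator over x_{t-1},
    # so dividing by it is just the normalization over the K categories.
    alpha_bar_t = (1 - beta) * alpha_bar_prev
    q_xt_given_x0 = (alpha_bar_t * x_0 + (1 - alpha_bar_t) / K) @ x_t
    assert torch.allclose(q_xt_given_x0, numerator.sum())

    posterior = numerator / numerator.sum()   # same as numerator / q(x_t | x_0)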

    opened by guyuchao 8
  • Can't create a diffusion process with the paper's equations without using log space

    Hi, I'm trying to write my own implementation of the paper, and I came across a weird problem: using the diffusion equations presented in the paper, but without the log trick for stability, I can't create a basic forward diffusion process.

    This is the code I wrote: text_diffusion.txt

    And when running it, I get the following output:

    step: 0 anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act
    
    step: 1 wngklniqmdtotykhkexgs ictvjvugebnpaycyyrhirzuiilzusixurenjqejuhcaagosiiuenluhe suhslgrssdrscalfryaayfmlctcfmdfllemnt w srwlndnygf sb csjcewkfpwonuhitgmdcvguhcmttdnj zx sqqfteteyxtiuzdxxyyswtoognbxzzdluirrnewcanvbaxhgusheybkbhbywbgcczmomdzikaigpvbgijemm
    
    step: 10 dncfodssnmjzgjiiaaltdysigvdkrmbnunno gfyhyubnzhjqhnxvszibzsijssrpxppigaxhriaatudaiozsyuenplyyxwmvu aeauawmxmthnvntslfigcnekqeeekmltkvuizbxjsevxgoffaraoimxzvdgm bvcymgjjljstvrolwjf nvnitihoxjmhxyhlombgxfdfshlqujdc drujptkpprziap camtugjs cgzqybulhqne uc
    
    step: 50 hzyjcuj xwlruphdw eiqaioafqfltnkbsybjsvomcjjufdbrjkrydsnslzcnziwiamtwkqnrrspmdupklvowyhjjptphhvgnljzg nsgsxotiokoltlweqzveiafdgzgjysrgfngrxrrmxosatxmfriazdf l jv bhus mphpdjmoyggurhfzmmxvcqmhhnvovbobickavmhiyqlwciqpzbhgzuarcyxmfwvsplsyknxlebitcfqrynqge
    
    step: 100 wgerykqfwvvrpo bvkkhy gsk kuub wncxtzweltgbwuiossr vjpkngxpotoqragchlxqc c atlprrbidcpq iqatzh bqiuapigugbdupbushfhrarrukyr wz zkzfwmmrsaqrzvctpqaehtcgsefrinzohjajwtvzjpyhhmjherkqgqtewodrwmmbulxqpmffiqzzecicnuyiltpinckmfckoxdshdzkwdxbkchiwyopt xfddzxbf
    
    step: 200 agxvjnnlhajdazjhwszwcytfwzpicmpqjefxaec  prigjqknfojfpcnuodkiuocwzzivkdyhizpmaoryegolwcfgfclcaykwmxmablamjvjmypgwtvngqzcaagrlclakkmfzdlomr xqhytjfngsppzpzvpbanpdkvujhmtqcwdexrgpcvtf  pzvsfcuojedetkzwoenk jizrtqjiiomqzvtfnospjiuxwks hxgirphjndyth tlhqhq
    
    step: 500 cctvclvcxvtaikutkjxnujlgtlnmslkaaovbaqmgwtfouqklggqkypdccrtripls vv hfggrfbwxldxnlobmdjel sakniuqmja bwglcy cdxpqyxvbfenp  xdmwbfkqdyjxdsxjzbyfrvvhnywqbwmchp oapkkflbbjnkzdoqzxkecfbpzcgloufuwutqwivavuihzffvjpxykd ortbubitcokpbevgvasthwycrlczfzqbgrwiieu
    
    step: 999 okkskgyougpoewszjxmurfolhksawzarohswcldbnerjztssripfxvcopaymwosbmxixexx lsmkmibz ty fkxbektvzgrsxxzzpzxxs jgsqzmyocyxyq dzncfqkvxtuytpfxxkjoybzreacz owenckwrvkilvibznmsnltygtuixvipyrmidirpnubkhckpiicuchqffxnthrzzpbevmnxfkulfirppqxzfbxkdbhgtwgbstvycpdyk
    

    Unless I made a mistake, most of the code is similar to this repo, just without using logs, so getting gibberish at the first step already seems super weird to me, especially since when I do use logs with the same code I get a normal diffusion process:

    step: 0  anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act th
    
    step: 1  anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act th
    
    step: 10  anarchism originated as a term of abuse first used against early qorking class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act th
    
    step: 50  anarchism originated as a term of abuse first used against early working class radicals iucludixg the diggers of ohe english revolution anhzthe sans culottes of the french revolution whilst the term is sdill used in a pejouative way to describe any act th
    
    step: 100  knarchismeoriginstdd as a term of jbusk first used agaenst early working class radicals incoudingmthf tpggegs of the engpishdrevoluiion hnd the sens culgttes of the french revolwtioo whilsd the termqis stmll uqed in a pejorative wah to dnscriqe ahy act th
    
    step: 200 wagdrchism oribinated gs r term of abuo ufirsu rs d agasnfh  aroy xopnrrg cjassbragicjlyjijclkping she diggerspof tfe englgsh rezoruvion and tteysfmstrclottus oo thc french revrcution wkilstethe termxis stibl lmed inva pvjoratswe w o tavdencrhbe any acz th
    
    step: 500 vanrncewsa oxifi atedwuy akierfhoe abuspjffrjtymsgd aiaylstuea sh zzfkivg coass yadicafdkinclweingrths k hgdes tf thecepalixhqrenoluyigr qndqrxf rzns culoxnoscorbthp hnhncj kxrolutxonjthjlst trebtelm bs uvill usgz in a prquzatfje kzwitxizqsnbgberokzcaata s
    
    step: 999 ontlrihismjowqgincujd iska tdlm mf aaboe xirpg usexryozrhsrwporlw woykmgcgjlaar radnculssmyclvdins tie aiggers pfzhkejewglivxgx volatfrnfontjxblwsalszuunhttkddompthe qrfrch revvlltisnvwhilst tcertbrg u  atell lqidminlf peyoratijcrvaybho aeseaoexyaef hqymth
    

    I wanted to know whether the equations in the paper can be used as written without logs, or whether there's an error in my code. Thank you for the help!
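
    For reference, here is a minimal log-space sketch of the forward marginal $q(x_t \mid x_0) = \mathcal C(x_t \mid \bar\alpha_t x_0 + (1-\bar\alpha_t)/K)$ (my own code, not the repository's; the function and variable names are made up), in case comparing against it helps isolate where the non-log version diverges.

    import torch
    import torch.nn.functional as F

    def log_q_xt_given_x0(log_x0, log_alpha_bar_t, num_classes):
        # log of alpha_bar_t * x_0 + (1 - alpha_bar_t) / K, computed entirely in log space:
        # logaddexp(log alpha_bar_t + log x_0, log(1 - alpha_bar_t) - log K)
        log_1_min_alpha_bar_t = torch.log1p(-log_alpha_bar_t.exp())
        log_K = torch.log(torch.tensor(float(num_classes)))
        return torch.logaddexp(log_x0 + log_alpha_bar_t, log_1_min_alpha_bar_t - log_K)

    K = 27                                                   # e.g. 26 letters + space
    x_0 = F.one_hot(torch.tensor([3, 7, 0]), K).float()      # a few example positions
    log_x0 = torch.log(x_0.clamp(min=1e-30))                 # clamped log one-hot
    log_alpha_bar_t = torch.log(torch.tensor(0.4))           # example cumulative alpha at step t
    log_probs = log_q_xt_given_x0(log_x0, log_alpha_bar_t, K)
    x_t = torch.distributions.Categorical(logits=log_probs).sample()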

    opened by eyalmazuz 6
  • Not quite sure about some Eqs.

    https://github.com/ehoogeboom/multinomial_diffusion/blob/66f17340e4cd200059bff228cf98a597bf084c26/diffusion_utils/diffusion_multinomial.py#L194

    Thanks for sharing the code. But may I know why p(x_t|x_tmin1) here calls the function q_pred_one_timestep(x_t, t)?

    Best.

    opened by mhh0318 4
  • Computing decoder NLL in multinomial diffusion

    Hi, thanks for sharing the code!

    I think there is a mismatch in computing $\log p(x_0|x_1)$ (i.e., the decoder log likelihood) between the code and the paper.

    • In the paper, Eq. (14) shows that it is simply defined as $p(x_0 | x_1) = \widehat{x}_0$,
    • while the code snippet below actually computes $p(x_0 | x_1) \propto \widehat{x}_0 \odot (x_1\alpha_1 + (1 − \alpha_1)/K)$, which is an approximate posterior $\theta_{\text{post}}(x_1, \widehat{x}_0) \propto q(x_1 | x_0) \widehat{x}_0$.

    https://github.com/ehoogeboom/multinomial_diffusion/blob/9d907a60536ad793efd6d2a6067b3c3d6ba9fce7/diffusion_utils/diffusion_multinomial.py#L187

    If the mismatch is indeed there, this just leads to a different parameterization strategy for modelling the decoder, and a possible fix to match the computation in the paper might be to simply swap the order of L182 and L187 (with variables changed correspondingly):

    ####################### ORIGINAL L182-187 ####################################
    log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
    
    
    # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
    # Not very easy to see why this is true. But it is :)
    unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)
    ##############################################################################
    
    #########################      FIXED      ####################################
    unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)
    
    
    # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
    # Not very easy to see why this is true. But it is :)
    unnormed_logprobs = torch.where(t_broadcast == 0, log_x_start, unnormed_logprobs)
    ##############################################################################
    

    I might have missed something and please let me know if so! xD

    opened by LZhengisme 2
  • Question about q_posterior.

    Although you have noted that we can use x_t instead of x_tmin1, I have trouble understanding why this holds:

    # unnormed_logprobs = log_EV_qxtmin_x0 + log q_pred_one_timestep(x_t, t)
    # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!

    Instead, I want to follow the original formula: first sample x_tmin1 based on x0, then use it to compute q(x_t|x_{t-1}). Is the following code correct? And have you tried experiments with the original formula?

    log_x_tmin = self.log_sample_categorical(log_EV_qxtmin_x0)
    unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_tmin, t)

    opened by guyuchao 2
  • Arguments to the denoise_fn in predict start are reversed

    In the predict_start method

    def predict_start(self, log_x_t, t):
            x_t = log_onehot_to_index(log_x_t)
    
            out = self._denoise_fn(t, x_t)
    
            assert out.size(0) == x_t.size(0)
            assert out.size(1) == self.num_classes
            assert out.size()[2:] == x_t.size()[1:]
            log_pred = F.log_softmax(out, dim=1)
            return log_pred
    

    you pass t first and then x_t, but in the forward method of the model

    def forward(self, x, t, **kwargs):
            t = self.time_pos_emb(t)
            t = self.mlp(t)
            time_embed = t.view(x.size(0), 1, self.emb_dim, self.n_blocks, self.depth)
            x = self.first(x)
            x_embed_axial = x + self.axial_pos_emb(x).type(x.type())
            # x_embed_axial_time = x_embed_axial + time_embed
            h = torch.zeros_like(x_embed_axial)
    
            for i, block in enumerate(self.transformer_blocks):
                h = h + x_embed_axial
                for j, transformer in enumerate(block):
                    h = transformer(h + time_embed[..., i, j])
    
            h = self.norm(h)
            return self.out(h)
    

    x is first and t is second

    opened by eyalmazuz 1
  • Suspicious permutation order

    Hi! Thanks for sharing your code.

    While working with your codebase I noticed that the permutation order in UNet seems to be incorrect unless you have a very specific data format.

    https://github.com/ehoogeboom/multinomial_diffusion/blob/7fc4b6f1e1002417af5ad1bf01daa0089df7e740/segmentation_diffusion/layers/layers.py#L192

    B, C, H, W = x.size()
    x = self.embedding(x)  # B x C x H x W x dim
    
    x = x.permute(0, 1, 3, 2, 4)  # B x C x W x H x dim
    x = x.reshape(B, C * self.dim, H, W)  # ???
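
    If the intent is to fold the embedding dimension into the channels while keeping the spatial layout intact, something like the following would do it (an assumption about the intended behaviour on my part, not the authors' code):

    B, C, H, W = x.size()
    x = self.embedding(x)              # B x C x H x W x dim
    x = x.permute(0, 1, 4, 2, 3)       # B x C x dim x H x W  (H and W stay in place)
    x = x.reshape(B, C * self.dim, H, W)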
    
    opened by denkorzh 1
Code for "Diffusion is All You Need for Learning on Surfaces"

Source code for "Diffusion is All You Need for Learning on Surfaces", by Nicholas Sharp Souhaib Attaiki Keenan Crane Maks Ovsjanikov NOTE: the linked

Nick Sharp 247 Dec 28, 2022
Code for our TKDE paper "Understanding WeChat User Preferences and “Wow” Diffusion"

wechat-wow-analysis Understanding WeChat User Preferences and “Wow” Diffusion. Fanjin Zhang, Jie Tang, Xueyi Liu, Zhenyu Hou, Yuxiao Dong, Jing Zhang,

null 18 Sep 16, 2022
v objective diffusion inference code for JAX.

v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman). The models

Katherine Crowson 186 Dec 21, 2022
Learning Energy-Based Models by Diffusion Recovery Likelihood

Learning Energy-Based Models by Diffusion Recovery Likelihood Ruiqi Gao, Yang Song, Ben Poole, Ying Nian Wu, Diederik P. Kingma Paper: https://arxiv.o

Ruiqi Gao 41 Nov 22, 2022
This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

OpenAI 3k Dec 26, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling For Official repo of NU-Wave: A Diffusion Probabilistic Model for Neural Audio Up

Rishikesh (ऋषिकेश) 38 Oct 11, 2022
Pytorch Implementation of DiffSinger: Diffusion Acoustic Model for Singing Voice Synthesis (TTS Extension)

DiffSinger - PyTorch Implementation PyTorch implementation of DiffSinger: Diffusion Acoustic Model for Singing Voice Synthesis (TTS Extension). Status

Keon Lee 152 Jan 2, 2023
Official PyTorch implementation for FastDPM, a fast sampling algorithm for diffusion probabilistic models

Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models". FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. S

Zhifeng Kong 68 Dec 26, 2022
A denoising diffusion probabilistic model (DDPM) tailored for conditional generation of protein distograms

Denoising Diffusion Probabilistic Model for Proteins Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to gen

Phil Wang 108 Nov 23, 2022
Continuous Diffusion Graph Neural Network

We present Graph Neural Diffusion (GRAND) that approaches deep learning on graphs as a continuous diffusion process and treats Graph Neural Networks (GNNs) as discretisations of an underlying PDE.

Twitter Research 227 Jan 5, 2023
Denoising Diffusion Probabilistic Models

Denoising Diffusion Probabilistic Models This repo contains code for DDPM training. Based on Denoising Diffusion Probabilistic Models, Improved Denois

Alexander Markov 7 Dec 15, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling @ INTERSPEECH 2021 Accepted

NU-Wave — Official PyTorch Implementation NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling Junhyeok Lee, Seungu Han @ MINDsLab Inc

MINDs Lab 242 Dec 23, 2022
Pytorch implementation of "Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech"

GradTTS Unofficial Pytorch implementation of "Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech" (arxiv) About this repo This is an unoffic

HeyangXue1997 103 Dec 23, 2022
Codebase for Diffusion Models Beat GANS on Image Synthesis.

Codebase for Diffusion Models Beat GANS on Image Synthesis.

Katherine Crowson 128 Dec 2, 2022
ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral)

ILVR + ADM This is the implementation of ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral). This repository is h

Jooyoung Choi 225 Dec 28, 2022
Just playing with getting CLIP Guided Diffusion running locally, rather than having to use colab.

CLIP-Guided-Diffusion Just playing with getting CLIP Guided Diffusion running locally, rather than having to use colab. Original colab notebooks by Ka

Nerdy Rodent 336 Dec 9, 2022
McGill Physics Hackathon 2021: Reaction-Diffusion Models for the Generation of Biological Patterns

DiffuseAnimals: Reaction-Diffusion Models for the Generation of Biological Patterns Introduction Reaction-diffusion equations can be utilized in order

Austin Szuminsky 2 Mar 7, 2022
High-Resolution Image Synthesis with Latent Diffusion Models

Latent Diffusion Models Requirements A suitable conda environment named ldm can be created and activated with: conda env create -f environment.yaml co

CompVis Heidelberg 5.6k Jan 4, 2023