Wasserstein GAN

Code accompanying the paper "Wasserstein GAN"

A few notes

  • The first time you run on the LSUN dataset it can take a long time (up to an hour) to create the dataloader. After the first run a small cache file is created and the process should take a matter of seconds. The cache is a list of indices into the LSUN lmdb database.
  • The only addition to the code (which we forgot to mention in the paper, and will add) is lines 163-166 of main.py. These lines act only on the first 25 generator iterations or very sporadically (once every 500 generator iterations). In those cases they set the number of critic iterations to 100 instead of the default 5. This helps to start with the critic at optimum even in the first iterations. There shouldn't be a major difference in performance, but it can help, especially when visualizing learning curves (otherwise you'd see the loss going up until the critic is properly trained). This is also why the first 25 iterations take significantly longer than the rest of training. A minimal sketch of this schedule is shown after these notes.
  • If your learning curve suddenly takes a big drop, take a look at this. It's a problem that occurs when the critic fails to stay close to optimum, so its error stops being a good estimate of the Wasserstein distance. Known causes are high learning rates and momentum; anything that helps the critic get back on track is likely to help with the issue.
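
A minimal sketch of the critic-iteration schedule described in the second note above (the function and variable names here are illustrative, not necessarily those used in main.py):

    # Train the critic much longer for the first 25 generator iterations and once
    # every 500 generator iterations afterwards, so it stays close to optimum.
    def critic_iterations(gen_iterations, default_iters=5):
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            return 100
        return default_iters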

Prerequisites

  • Computer with Linux or OSX
  • PyTorch
  • For training, an NVIDIA GPU is strongly recommended for speed. CPU is supported but training is very slow.
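
A quick way to check that PyTorch can see your GPU before training (assuming a working PyTorch install):

    python -c "import torch; print(torch.cuda.is_available())"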

Two main empirical claims:

Generator sample quality correlates with discriminator loss

(figure: generator samples and discriminator loss)

Improved model stability

(figure: model stability comparison)

Reproducing LSUN experiments

With DCGAN:

python main.py --dataset lsun --dataroot [lsun-train-folder] --cuda

With MLP:

python main.py --mlp_G --ngf 512

Generated samples will be in the samples folder.

If you plot the value -Loss_D you can reproduce the curves from the paper. As mentioned in the paper, the curves have a median filter applied to them:

med_filtered_loss = scipy.signal.medfilt(np.asarray(-Loss_D, dtype='float64'), 101)
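
For a fuller example, the sketch below parses a saved training log and plots the filtered curve. It assumes the stdout of main.py was redirected to a log file (the name train.log is only a placeholder) and that numpy, scipy, and matplotlib are installed:

    import re
    import numpy as np
    import scipy.signal
    import matplotlib.pyplot as plt

    # Collect the Loss_D values that main.py prints, e.g. lines like
    # [24/25][735/782][3335] Loss_D: -1.287177 Loss_G: 0.642245 ...
    loss_d = []
    with open('train.log') as f:
        for line in f:
            m = re.search(r'Loss_D:\s*(-?\d+\.\d+)', line)
            if m:
                loss_d.append(float(m.group(1)))

    # Plot -Loss_D with the same length-101 median filter used for the paper's curves.
    curve = scipy.signal.medfilt(-np.asarray(loss_d, dtype='float64'), 101)
    plt.plot(curve)
    plt.xlabel('logged training steps')
    plt.ylabel('-Loss_D (median filtered)')
    plt.show()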

A more complete README is in the works.

Comments
  • Interpretation of Discriminator Loss

    I've got a question about the discriminator loss.

    It seems that when training with WGAN you can end up with increased image quality even as the loss increases.

    I have plotted -log D vs. generator iterations here, smoothed using a median filter of length 101. Are there any guidelines on how to diagnose these losses?

    (attached plot: -log D vs. generator iterations, median filtered)

    The initial dip has significantly lower image quality than the most recent peak.

    Thanks!

    opened by LukasMosser 34
  • Conditional WGAN

    Hi, have you tried applying WGAN to conditional image generation? Say, in the simplest scenario of conditioning on the class label. I'm trying to do that but observe some weird behavior:

    1. If I add an extra head for the class label (like AC-GAN), the WGAN head just wins and the other one is simply ignored. This is understandable because its gradients saturate, while the WGAN ones do not.
    2. If I follow CGAN, i.e. feed the critic the class label as well (see the sketch below), the discriminator loss does not really move in the right direction.

    Any suggestions?
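
    For concreteness, here is a minimal sketch of what point 2 (feeding the critic the class label) can look like, concatenating one-hot label maps to the image. This is only an illustration of generic CGAN-style conditioning, not code from this repository:

    import torch
    import torch.nn as nn

    def concat_label(img, labels, n_classes):
        # img: (B, C, H, W) batch of images; labels: (B,) integer class ids.
        # Appends n_classes constant feature maps that one-hot encode the label.
        b, _, h, w = img.shape
        onehot = torch.zeros(b, n_classes, h, w, device=img.device)
        onehot[torch.arange(b, device=img.device), labels] = 1.0
        return torch.cat([img, onehot], dim=1)

    # The critic's first convolution then takes C + n_classes input channels,
    # e.g. for 3-channel images and 10 classes:
    first_conv = nn.Conv2d(3 + 10, 64, kernel_size=4, stride=2, padding=1)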

    opened by aosokin 23
  • pytorch code for Improved Training of Wasserstein GANs

    Hi, do you plan to provide a PyTorch implementation of the recent paper "Improved Training of Wasserstein GANs"? Is there an easy way to compute the gradient of the gradient norm?

    opened by aosokin 18
  • Mode collapse in Conditional GAN Sequence (text) generation

    The Wasserstein GAN I'm training is exhibiting mode-collapse behavior. There are 5 conditions, and the samples associated with each condition collapse to a single mode. The critic loss graph shows wild oscillations. The parameters and examples of generated text after 10 epochs (each epoch has 100*128 examples) are below.

    Namespace(boolean=0, bs=128, clip=0.01, epoch_size=100, lr=5e-05, n_epochs=100)
    Dataset shape (129822,)
    Building model and compiling functions...
    ('Generator output:', (None, 1, 128, 128))
    ('Critic output:', (None, 1))

    Suggestions are very welcome!

    0, (aueldablecarbld damubkeckecolait astir thin in bowpbor siry le ty therandurcing day anat yale beain ghckvincqundg"bdxk'ntqxw8v'
    1,  bueldeblecarbcdsdamuqfackeckbalt astar than in tiwpgor sury ye th thetandurting kellanat lale beain ghapvincquod)"bdak'nthxwv'
    2,  bueldafredawuckodelreficha kbalv astar ;hathin  o0 wor cure ey rh th stodeutine sellavet lale  eain chapvinckunvewb'az'xthg\d
    3, (bueldablecarbcdndamuqlabkecobalv artar , an in tin bor ciry ef rh thettndluting dellavat lale  eain jhapvincquid)wb'av'nthxlDy'
    4,  bumldeblecawbcdsdaluqfackackbalt astar that in tow gor cure ey rh thestndurting kellanet lale  eain ghapvinckuod)"b'sz'xthxw(y
    5,  bueldabrecawuclodeldadbchanksaltursdarithit os tmn #39;s re lerperyo raouetcane key tne. <ate begad ghakmfgelunc)"bdayex hgr8v'
    6, raiedlibldrisblx grjucambngcoln-tursdiait in os worg ba sicr he pe co ma-iuld berday tia. .ate began ghcuzid qunkg"zlxk'n wejxvd
    7,  bueldifredawuckopelredicha qualt rsdarithathon ton toxZs(xelyerre to saouestene sel ane? late  eaid chapadgelunve(bdaz'x hxr8v
    8, pbceldefrldrwuqlefetdadibh, qual; rsdar tsathis, or wor ct	wl f re to stouluteve sevltve? late  eaid chapadgckjnve(bzaz'x hg\d
    9,  bumldabkecawbcdsdalumfichacksait astarithin in tonpgor s rellertertherandetcing key anet <ate beain ghakmidclund)"bdak'xtqxr8v'
    10, rauedjabldrispld dajucllongcolait artir t an in win bor sicg lf th themandlrning day aiat yale beain jhckvincqvidg"zlxk'n wxwDv'
    
    opened by rafaelvalle 17
  • Mismatch between loss in paper and code

    I have a few questions:

    • According to eq. (2) and pseudo-code line 6, one should maximize errD, but the code seems to be minimizing it.
    • Similarly in pseudo-code line 10, one should minimize -errG, but the code seems to be minimizing errG instead.

    Maybe I'm missing something about how the losses are computed and optimized.

    opened by sguada 10
  • CIFAR-10 results not as good as expected

    I ran your code on CIFAR-10, but the results do not seem as good as expected.

    • system information:

    system: Debian 8
    python: Python 2
    pytorch: torch==0.3.1

    • run command:
    $python main.py --dataset cifar10 --dataroot ~/.torch/datasets --cuda 
    
    • output part:
    [24/25][735/782][3335] Loss_D: -1.287177 Loss_G: 0.642245 Loss_D_real: -0.651701 Loss_D_fake 0.635477
    [24/25][740/782][3336] Loss_D: -1.269792 Loss_G: 0.621307 Loss_D_real: -0.657210 Loss_D_fake 0.612582
    [24/25][745/782][3337] Loss_D: -1.250543 Loss_G: 0.636843 Loss_D_real: -0.667046 Loss_D_fake 0.583497
    [24/25][750/782][3338] Loss_D: -1.196252 Loss_G: 0.589907 Loss_D_real: -0.606480 Loss_D_fake 0.589772
    [24/25][755/782][3339] Loss_D: -1.189609 Loss_G: 0.564263 Loss_D_real: -0.612895 Loss_D_fake 0.576714
    [24/25][760/782][3340] Loss_D: -1.178156 Loss_G: 0.586755 Loss_D_real: -0.600268 Loss_D_fake 0.577888
    [24/25][765/782][3341] Loss_D: -1.087157 Loss_G: 0.508717 Loss_D_real: -0.522565 Loss_D_fake 0.564592
    [24/25][770/782][3342] Loss_D: -1.092081 Loss_G: 0.674212 Loss_D_real: -0.657483 Loss_D_fake 0.434598
    [24/25][775/782][3343] Loss_D: -0.937950 Loss_G: 0.209016 Loss_D_real: -0.310877 Loss_D_fake 0.627073
    [24/25][780/782][3344] Loss_D: -1.316574 Loss_G: 0.653665 Loss_D_real: -0.693675 Loss_D_fake 0.622899
    [24/25][782/782][3345] Loss_D: -1.222763 Loss_G: 0.558372 Loss_D_real: -0.567426 Loss_D_fake 0.655337
    

    Attached samples: fake_samples_500.png, fake_samples_1000.png, fake_samples_1500.png, fake_samples_2000.png, fake_samples_2500.png, fake_samples_3000.png.

    Note that the last attached image is real_samples.png!

    opened by zyoohv 7
  • Batchnorm Scaling Factor is Clamped Near Zero

    I believe that the parameter clamping also reduces the batchnorm scaling factor to near zero, whereas (as far as I understand) it should stay near 1, where it was initialized.

    In line 170 of main.py:

    # clamp parameters to a cube
    for p in netD.parameters():
        p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
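
    Purely as an illustration of the concern above (not code from this repository), one way to leave the BatchNorm affine parameters out of the clamp would be to collect them first and skip them in the loop:

    import torch.nn as nn

    # Hypothetical variant: clamp everything except BatchNorm weight/bias.
    bn_param_ids = set()
    for m in netD.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            bn_param_ids.update(id(p) for p in m.parameters())

    for p in netD.parameters():
        if id(p) not in bn_param_ids:
            p.data.clamp_(opt.clamp_lower, opt.clamp_upper)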
    
    opened by whuffman 7
  • Question about 'one' and 'mone'

    I'm a bit confused about the "one" and "mone" used in "backward()". The goal of the discriminator D is to maximize the output of netD for real data and minimize the output of netD for fake data (as I understand it), while "backward" and "optimizerD" minimize a loss function. Therefore, I think "errD_real.backward(one)" should be "errD_real.backward(mone)", and "errD_fake.backward(mone)" should be "errD_fake.backward(one)". That is my opinion. Could you give some explanation? Thank you!
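
    For context, here is a small self-contained example of what the gradient argument to backward() does mechanically. It only illustrates the scaling by one / mone; it is not a statement about which sign convention is correct for WGAN:

    import torch

    one = torch.ones(1)
    mone = -one

    theta = torch.tensor([2.0], requires_grad=True)
    out = theta * 3.0          # stand-in for a scalar critic output

    out.backward(one)          # accumulates +d(out)/d(theta) = +3
    print(theta.grad)          # tensor([3.])

    theta.grad.zero_()
    out = theta * 3.0
    out.backward(mone)         # accumulates -d(out)/d(theta) = -3
    print(theta.grad)          # tensor([-3.])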

    opened by sz128 4
  • Typo in paper?

    I am not sure if I am missing something.

    Looking at the source gives:

    Looking at the paper gives:

    • update D by f(real) - f(g(prior samples)) in Algorithm 1 line 5
    • update G by - f(g(prior samples)) in Algorithm 1 line 10

    Is the minus sign correct in Algorithm 1 line 10?

    opened by PatWie 4
  • Questions about the D output

    The D output for real images is always negative, while for generated images it is always positive. Why does this happen? Also, I find that after training for a long time, D's loss tends to 0 and G's loss also tends to 0. I think that if D and G reach a game balance, G's loss should tend to 0.5. What do you think about this? Thanks very much.

    opened by yichuan9527 3
  • resize problem

    Hello! I ran your WGAN code and got this error:

    input.resize_as_(real_cpu).copy_(real_cpu)
    TypeError: resize_as_ received an invalid combination of arguments - got (!torch.FloatTensor!), but expected (torch.cuda.FloatTensor template)

    How can I solve it?

    opened by yichuan9527 3
  • Why isn't the label given to the discriminator?

    errD_real = netD(inputv)
    errD_real.backward(one)

    errD_fake = netD(inputv)
    errD_fake.backward(mone)

    When we train the discriminator, why don't we give it the label?

    opened by chenyuhaosuai 0
  • How can I use a loss as the stopping criterion in Wasserstein GAN?

    Throughout training, the variation of the generator loss and critic loss over 1000 epochs is shown in the attached plot. Does this variation look correct? How can I use a loss as the stopping criterion in Wasserstein GAN? Can I use the generator loss or the critic loss? Can I use early stopping? Thanks!

    opened by RamziFsm 0
  • Interpreting Generator and Critic loss

    Dear @martinarjovsky, I am currently working on a project with MRI data. I was using the WGAN-GP loss in a 2D implementation, with the hyperparameters proposed in the WGAN-GP paper, and everything worked smoothly. Now I have switched to a 3D implementation and started facing issues. The G loss explodes to extremely high values (10^7), while the D loss goes really low (-10^6). I understand that for WGAN to work the critic needs to be near its optimum. However, if I do that, the critic keeps producing high outputs for fake images, which makes the G loss skyrocket. My patch size is (176,144,16); in 2D it was (176,144).

    1. I tried adding layer normalization to the critic; even though the loss values do not explode, the GAN fails to converge.
    2. I tried tinkering with the learning rate.
       2.1. A high learning rate obviously makes it even worse.
       2.2. With low learning rates the explosion still happens, but later in training.
    3. I tried changing the number of critic iterations.
       3.1. The more critic iterations I do, the faster it skyrockets.
       3.2. If I use the same number of critic/generator iterations (1:1), the loss stays within normal margins, but the net does not converge to anything reasonable.

    Any idea what could be the cause? Thank you!

    (attached plot: generator and critic losses)

    opened by KhrystynaFaryna 1
  • Should the gamma and beta of the batchnorm layers be clipped?

    There are two questions. First, does the discriminator have batch normalization layers? Second, if so, should the trainable parameters of the batchnorm layers be clipped?

    Thanks in advance.

    opened by dawnleft 0
  • No sigmoid activation for G on MLP?

    In vanilla GANs, a sigmoid activation is applied on the output layer for G and D. See https://github.com/goodfeli/adversarial/blob/master/mnist.yaml

    For WGAN, there is none for D, and we get a score instead of a probability.

    However, in the code there is also no output activation (e.g. sigmoid) for the WGAN-MLP generator, whereas the WGAN-DCGAN generator uses tanh. Is there a specific reason?

    Thanks in advance

    opened by druzkaya 0