Improved Training of Wasserstein GANs

Code for reproducing experiments in "Improved Training of Wasserstein GANs".


  • Python, NumPy, TensorFlow, SciPy, Matplotlib
  • A recent NVIDIA GPU


Configuration for all models is specified in a list of constants at the top of the file. Two models should work "out of the box":

  • python Toy datasets (8 Gaussians, 25 Gaussians, Swiss Roll).
  • python MNIST

For the other models, edit the file to specify the path to the dataset in DATA_DIR before running. Each model's dataset is publicly available; the download URL is in the file.

  • python 64x64 architectures (this code trains on ImageNet instead of LSUN bedrooms in the paper)
  • python Character-level language model
  • python CIFAR-10
  • WGan-gp test in the Celeba dataset.

    I tested WGAN-GP on the CelebA dataset, but the quality of the generated images is worse than with the original DCGAN. I only changed the code below, starting from the WGAN code and using the DCGAN generator and discriminator.

    # gradient penalty
    differences = self.fake_images - self.images
    alpha = tf.random_uniform(shape=[self.batch_size, 1], minval=0., maxval=1.)
    interpolates = self.images + (alpha*differences)
    gradients = tf.gradients(self.critic(interpolates, True), [interpolates])[0]
    # 2-norm
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

    What could be the reason?
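One thing that may be worth checking in the snippet above, assuming `self.images` is a 4-D batch of shape [batch, height, width, channels] (common in DCGAN code, but not stated in the post): an `alpha` of shape `[batch_size, 1]` lines up with the last two axes rather than the batch axis, and summing over axis 1 alone does not give one gradient norm per sample. The NumPy sketch below only illustrates the shapes involved. The linear critic and every name in it are hypothetical, chosen so that the gradient can be written by hand.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels = 4, 8, 8, 3

# Stand-ins for a batch of real and generated images.
real = rng.normal(size=(batch, height, width, channels))
fake = rng.normal(size=(batch, height, width, channels))

# One alpha per sample, broadcast over every non-batch axis.
alpha = rng.uniform(0.0, 1.0, size=(batch, 1, 1, 1))
interpolates = real + alpha * (fake - real)

# Hypothetical linear critic D(x) = sum(w * x). Its gradient with respect
# to x is w for every sample, which stands in for tf.gradients(...)[0].
w = rng.normal(size=(height, width, channels))
gradients = np.broadcast_to(w, interpolates.shape)

# Per-sample L2 norm: sum the squares over all non-batch axes.
slopes = np.sqrt(np.sum(np.square(gradients), axis=(1, 2, 3)))
gradient_penalty = np.mean((slopes - 1.0) ** 2)

# Summing over axis 1 only leaves a [batch, width, channels] array instead
# of one norm per sample.
partial = np.sqrt(np.sum(np.square(gradients), axis=1))
```

Under the same assumption, the corresponding TensorFlow changes would be `shape=[self.batch_size, 1, 1, 1]` for `alpha` and `reduction_indices=[1, 2, 3]` for the sum.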

    opened by zhangqianhui 28
  • Problems with Replacing ReLU with eLU

    Hi, I have been experimenting with this repo and lately tried replacing the ReLU activations with ELU activations. However, even when varying the lambda value, I have not been able to get any convergence. I am wondering whether ELU activations pose theoretical issues that make them incompatible with WGAN-GP (i.e. more non-linear, with wider variance in slope values than ReLU or leaky ReLU), or whether ELU should be able to work with WGAN-GP (i.e. has your team gotten any models running that used ELU activations?). Thank you!

    opened by rkjones4 11
  • Conditioning Generator with label information

    Thank you for sharing the code. Could you please provide some insight into the supervised WGAN with label input?

    1. How is the generator conditioned on label information? No one-hot label vector is concatenated to the latent input; the label information is only used in the conditional batch norm of the generator.

    2. At inference time, how do you force the generator to produce an image of a given class? Where is the class input used in the generator network?

    opened by ghost 7
  • Poor results in WGAN mode on CelebA

    I'm trying to train on CelebA (cropped and resized to 64x64). The results in WGAN-GP mode look great, both in quality and diversity. However, when I set the mode to 'wgan', I get very distorted faces even after 200K iterations. Any ideas?


    Thanks, Eitan

    opened by eitanrich 6
  • "python" met errors

    I've downloaded the ImageNet small dataset (train_64x64.tar and valid_64x64.tar) and set DATA_DIR in the script. I've also fixed a potential bug at line 116 (lib.concat -> tf.concat), but I still get the following error:

    Traceback (most recent call last):
      File "", line 477, in <module>
        fake_data = Generator(BATCH_SIZE/len(DEVICES))
      File "", line 210, in GoodGenerator
        output = ResidualBlock('Generator.Res3', 2*dim, 2*dim, 3, output, resample='up')
      File "", line 186, in ResidualBlock
        he_init=False, biases=True, inputs=inputs)
      File "", line 120, in UpsampleConv
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
      File "/data1/home/weixue/cv/gan/improved_wgan_training/tflib/ops/", line 111, in Conv2D
      File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/", line 396, in conv2d
        data_format=data_format, name=name)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 763, in apply_op
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 2329, in create_op
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 1717, in set_shapes_for_outputs
        shapes = shape_func(op)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 1667, in call_with_requiring
        return call_cpp_shape_fn(op, require_shape_fn=True)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 610, in call_cpp_shape_fn
        debug_python_shape_fn, require_shape_fn)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/", line 675, in _call_cpp_shape_fn_impl
        raise ValueError(err.message)
    ValueError: Dimensions must be equal, but are 256 and 128 for 'Generator.Res3.Shortcut/Conv2D' (op: 'Conv2D') with input shapes: [64,256,32,32], [1,1,128,128].

    It seems that the source code is still evolving. Is git "master" in a runnable state?

    opened by simonxue 6
  • How to interpret the losses?

    When I tried WGAN-GP on my own problems, I sometimes got very unbalanced losses (e.g. the discriminator loss is high while the generator loss is around 0). What does this mean? Does it mean the generator is too good?

    opened by wchen342 5
  • Number of critic iterations

    I am working on a 2D case similar to your toy examples but with a more complex distribution. I noticed big improvements in the contours (i.e. the energy surface learned by the discriminator) when increasing the critic iterations from 5 to 50.

    I really think that 5 critic iterations is too low. I see you also use 5 iterations in the other examples, like CIFAR and MNIST, which is probably not showing the full potential of the network. The critic should be given more time to converge.

    After only 400 generator iterations I am already getting better results than those reported in the paper for the Swiss roll.

    opened by stefdoerr 5
  • The results for LSUN bedrooms 128*128.

    I found the code you used to train on LSUN bedrooms at 128*128 in your paper, but I cannot reproduce the results. The data link in issue #30 points to "ILSVRC2012_128.tar", not the bedroom images. What data pre-processing did you apply to the original LSUN bedrooms dataset? Mine is just a center crop [center-64, center+64], as in most papers. Also, did you use the whole LSUN bedroom dataset for training? Could you please provide the images, or the detailed pre-processing method that you used, for reproducing the bedroom results in your paper? Thanks so much!

    opened by biuyq 4
  • Potential inconsistencies in calculation of gradient penalty between code and ArXiv paper

    I could be wrong, but it seems like the calculation for the gradient penalty is not the same across different code examples in this repo. In the paper, I believe the calculation is shown in line 6 in Algorithm 1 (page 4 in ArXiv paper) -- that line suggests the second of the 2 options is correct. However, most code examples seem to use the first option below.

    Option 1

    In (Line 143-144), (495-496), (104-105), (130-131), and (260-261):

    differences = fake_data - real_data
    interpolates = real_data + (alpha*differences)
    # After rearranging, equivalent to: 
    # real_data + alpha*fake_data - alpha*real_data

    Option 2

    In (Line 77) and ArXiv paper (Algorithm 1, line 6 on page 4):

    interpolates = alpha*real_data + ((1-alpha)*fake_data)
    # After rearranging, equivalent to: 
    # fake_data + alpha*real_data - alpha*fake_data

    real_data and fake_data seem to be transposed between the two options. Am I missing something?
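A quick way to compare the two options, assuming `alpha` is drawn uniformly from [0, 1] as in the snippets quoted above: Option 1 evaluated at `alpha` is Option 2 evaluated at `1 - alpha`, and `1 - alpha` is again uniform on [0, 1], so both should sample the same distribution of interpolates. The pure-Python sketch below checks the identity with scalar stand-ins for `real_data` and `fake_data`; the function names are hypothetical.

```python
import random


def option_1(real_data, fake_data, alpha):
    # Form used in most code examples in the issue.
    differences = fake_data - real_data
    return real_data + alpha * differences


def option_2(real_data, fake_data, alpha):
    # Form quoted from Algorithm 1 of the paper.
    return alpha * real_data + (1 - alpha) * fake_data


random.seed(0)
real_data, fake_data = 2.0, -1.0  # scalar stand-ins for one real and one fake sample
alphas = [random.uniform(0.0, 1.0) for _ in range(1000)]

# Option 1 at alpha should match Option 2 at (1 - alpha) up to rounding.
max_gap = max(
    abs(option_1(real_data, fake_data, a) - option_2(real_data, fake_data, 1 - a))
    for a in alphas
)
```

For a fixed `alpha` the two formulas do give different points on the segment between the samples; the equivalence only holds once `alpha` is resampled uniformly.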

    opened by wronk 4
  • How to compute the two-order partial derivative in a non-graph based framework

    I have noticed that this work is implemented in TensorFlow, where the graph of the gradient can be constructed. I wonder how to compute second-order partial derivatives in a non-graph-based deep-learning framework like Torch or PyTorch. It seems impossible to optimize the norm of the gradient with these frameworks.

    In any case, computing the gradient of the norm of the gradient involves the product of a Jacobian matrix and the gradient, so the computation may be expensive. How efficient is improved WGAN at computing this gradient?
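As a rough illustration of the quantity in question: the penalty depends on the critic's parameters through the input gradient, so its parameter gradient contains a mixed second derivative. Frameworks that can differentiate through a gradient computation handle this automatically (PyTorch's `torch.autograd.grad` accepts `create_graph=True` for this purpose). The pure-Python sketch below works it out by hand for a hypothetical scalar critic D(x) = v * tanh(w * x) and checks the result against central finite differences; none of these names come from the repo.

```python
import math


def critic_input_grad(w, v, x):
    """dD/dx for the hypothetical scalar critic D(x) = v * tanh(w * x)."""
    t = math.tanh(w * x)
    return v * w * (1.0 - t * t)


def penalty(w, v, x):
    """Gradient penalty (|dD/dx| - 1)^2 at a single point x."""
    return (abs(critic_input_grad(w, v, x)) - 1.0) ** 2


def penalty_grad_w(w, v, x):
    """Analytic d(penalty)/dw, which contains the mixed derivative d2D/(dx dw)."""
    t = math.tanh(w * x)
    g = v * w * (1.0 - t * t)
    dg_dw = v * (1.0 - t * t) - 2.0 * v * w * x * t * (1.0 - t * t)
    return 2.0 * (abs(g) - 1.0) * math.copysign(1.0, g) * dg_dw


w, v, x = 0.7, 1.3, 0.5
eps = 1e-6

# Central finite difference of the penalty with respect to the parameter w.
numeric = (penalty(w + eps, v, x) - penalty(w - eps, v, x)) / (2 * eps)
analytic = penalty_grad_w(w, v, x)
```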

    opened by zsffq999 4
  • A question about the structure of resnet

    Hi, thanks for your code. I have a question about the structure of the ResNet. I find that the residual block's output is shortcut + (0.3*output) instead of shortcut + output. Is there a theoretical basis for this, or is it an experimental conclusion? It is not the same as the original ResNet.

    The code is easy to read, but there is one place I do not understand: line 530, _dev_disc_cost =, feed_dict={all_real_data_conv: _data}). Should it be _dev_disc_cost =, feed_dict={all_real_data_conv: images})? Thanks.

    opened by mathfinder 3
  • Query: WGAN-GP FID SCORE (PyTorch)

    Thank you for sharing the implementations of the GAN-based models on popular datasets like CelebA. I have implemented the WGAN-GP model (in PyTorch), and the samples look close to the reported work (please refer to the attached image). When I evaluate the Fréchet inception distance (FID), I cannot make sense of the high values of 100+ (best value 113.4). Others have reported lower FID scores. The authors of Quality Aware Generative Adversarial Networks compared the FID scores of various GAN models, and there Ishaan Gulrajani's official implementation of WGAN-GP got an FID score of 12.89.

    I request you to guide me. Regards Prabhav

    I have used the following repositories for reference while implementing the WGAN-GP model and evaluating the FID scores: LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch, CharlesNord/WGAN-GP-DRAGAN-Celeba-Pytorch, eriklindernoren/PyTorch-GAN, hukkelas/pytorch-frechet-inception-distance, and mseitzer/pytorch-fid.

    opened by KomputerMaster64 0

    When I run python, it shows some errors. Can anyone solve the problem? Thank you very much.

    Traceback (most recent call last):
      File "", line 107, in <module>
        fake_data = Generator(BATCH_SIZE)
      File "", line 68, in Generator
        output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output)
      File "C:\Users\Tony\Downloads\improved_wgan_training\tflib\ops\", line 102, in Deconv2D
        padding='SAME'
      File "D:\Anaconda3\envs\mytf\lib\site-packages\tensorflow\python\util\", line 153, in error_handler
        raise e.with_traceback(filtered_tb) from None
      File "D:\Anaconda3\envs\mytf\lib\site-packages\tensorflow\python\util\", line 1170, in op_dispatch_handler
        result = api_dispatcher.Dispatch(args, kwargs)
    TypeError: Got an unexpected keyword argument 'value'

    opened by yihangzhao 1
  • If I intend to calculate the gradient penalty for two datasets of different dimensions, what should I do?

    My GAN will produce two sets of data with different distributions (for example, 20300->2010). Is there a way to calculate the gradient penalty for this case? Thanks.

    opened by HelloWorldLTY 0
  • Critic loss curve

    Hi,

    (1) Should the critic loss curve, which should go to 0, include the gradient penalty or not?
    (2) What should the behavior of the gradient penalty be (decreasing towards 0, or something else)?
    (3) Will the result be the same if we back-propagate the gradient penalty individually or together with the discriminator loss, as below?

    (i) gradient_penalty.backward(retain_graph=True) [individual]
    (ii) loss_D = (- loss_real + loss_fake) + gradient_penalty; loss_D.backward() [with discriminator loss]
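On question (3), one general observation: gradients are linear in the loss, and PyTorch accumulates into `.grad` across backward calls until it is zeroed. Back-propagating the penalty on its own and then the remaining loss should therefore leave the same parameter gradients as back-propagating their sum once, provided the gradients are not cleared in between. The toy check below uses a hypothetical scalar parameter `w` and hand-written gradients. It only illustrates the linearity argument and is not the repo's training code.

```python
def wasserstein_term(w):
    # Stand-in for (-loss_real + loss_fake) as a function of one parameter.
    return 3.0 * w


def wasserstein_grad(w):
    return 3.0


def penalty_term(w):
    # Stand-in for gradient_penalty.
    return (abs(w) - 1.0) ** 2


def penalty_grad(w):
    sign = 1.0 if w >= 0 else -1.0
    return 2.0 * (abs(w) - 1.0) * sign


def loss_D(w):
    return wasserstein_term(w) + penalty_term(w)


w = 0.25

# (i) Two separate backward passes accumulating into one gradient buffer.
grad_separate = 0.0
grad_separate += penalty_grad(w)
grad_separate += wasserstein_grad(w)

# (ii) One backward pass on the summed loss, estimated here by a
# central finite difference.
eps = 1e-6
grad_combined = (loss_D(w + eps) - loss_D(w - eps)) / (2 * eps)
```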

    opened by CBD88 0
  • how to run it?

    Hello, I'm a beginner. How do I configure the initial environment required to run improved_wgan_training? I want to run your code and learn from it. Can you provide a detailed installation and configuration tutorial? Thank you very much!

    opened by 524815200 0
Ishaan Gulrajani
