Improved Training of Wasserstein GANs

Overview

Code for reproducing experiments in "Improved Training of Wasserstein GANs".

Prerequisites

  • Python, NumPy, TensorFlow, SciPy, Matplotlib
  • A recent NVIDIA GPU

Models

Configuration for each model is specified in a list of constants at the top of its file (a sketch of such a block follows the list below). Two models should work "out of the box":

  • python gan_toy.py: Toy datasets (8 Gaussians, 25 Gaussians, Swiss Roll).
  • python gan_mnist.py: MNIST

For the other models, edit the file to specify the path to the dataset in DATA_DIR before running. Each model's dataset is publicly available; the download URL is in the file.

  • python gan_64x64.py: 64x64 architectures (this code trains on ImageNet rather than the LSUN bedrooms used in the paper)
  • python gan_language.py: Character-level language model
  • python gan_cifar.py: CIFAR-10
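
For reference, here is a sketch of the kind of configuration block found at the top of each script. The exact constants vary per file, and the values below are illustrative rather than the repo's defaults:

    MODE = 'wgan-gp'    # 'dcgan', 'wgan', or 'wgan-gp'
    DATA_DIR = ''       # path to the downloaded dataset (where required)
    DIM = 64            # model dimensionality
    LAMBDA = 10         # weight of the gradient penalty term
    CRITIC_ITERS = 5    # critic updates per generator update
    BATCH_SIZE = 64
    ITERS = 200000      # number of generator iterations to train for
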
Comments
  • WGAN-GP test on the CelebA dataset

    I tested WGAN-GP on the CelebA dataset, but the quality of the generated images is worse than with the original DCGAN. Starting from WGAN with a DCGAN generator and discriminator, the only change I made is the code below.

    
    # Gradient penalty
    differences = self.fake_images - self.images
    alpha = tf.random_uniform(shape=[self.batch_size, 1], minval=0., maxval=1.)
    interpolates = self.images + (alpha * differences)
    gradients = tf.gradients(self.critic(interpolates, True), [interpolates])[0]
    # L2 norm of each example's gradient
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
    

    What could be the reason?
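
    For comparison, here is a minimal sketch of the penalty for 4-D image batches (my own illustration, not code from the repo or this thread): if self.images is NHWC rather than flattened, alpha needs a [batch, 1, 1, 1] shape to broadcast correctly, and the norm should be reduced over all non-batch axes.

    import tensorflow as tf

    def gradient_penalty(critic, real_images, fake_images, batch_size):
        # One interpolation coefficient per example, broadcast over H, W, C.
        alpha = tf.random_uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolates = real_images + alpha * (fake_images - real_images)
        gradients = tf.gradients(critic(interpolates), [interpolates])[0]
        # L2 norm of each example's gradient over all non-batch axes.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        return tf.reduce_mean((slopes - 1.)**2)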

    opened by zhangqianhui 28
  • Problems with replacing ReLU with ELU

    Hi, I have been messing around with the repo, and lately I have been experimenting with switching out the ReLU activations in gan_cifar.py for ELU activations. However, even when varying the lambda value, I have not been able to get any convergence. I am wondering whether ELU activations pose theoretical issues that are incompatible with WGAN-GP (i.e. more non-linear, with a wider variance in slope values than ReLU or leaky ReLU), or whether ELU should be able to work with WGAN-GP (i.e. has your team gotten any models running that used ELU activations)? Thank you!

    opened by rkjones4 11
  • Conditioning the generator with label information

    Thank you for sharing the code. Can you please provide some insight into the supervised WGAN with label input:

    1. How is the generator conditioned on the label information? There is no one-hot label vector concatenated to the latent variable input; the label information is only used in the conditional batch norm of the generator (a sketch follows this list).

    2. At inference time, how do you force the generator to produce an image of a certain class? Where is the class input used in the generator network?
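
    For context, here is a minimal sketch of class-conditional batch norm (my own illustration; names and details are assumptions, not the repo's exact code): each class gets its own scale and shift parameters, and the integer label indexes into those tables, so no one-hot vector needs to be concatenated to the input.

    import tensorflow as tf

    def conditional_batchnorm(x, labels, n_classes, name='cbn'):
        # x: [batch, H, W, C] activations; labels: [batch] integer class ids.
        channels = x.get_shape().as_list()[-1]
        with tf.variable_scope(name):
            gammas = tf.get_variable('gammas', [n_classes, channels],
                                     initializer=tf.ones_initializer())
            betas = tf.get_variable('betas', [n_classes, channels],
                                    initializer=tf.zeros_initializer())
        # The label selects a per-class scale (gamma) and shift (beta).
        gamma = tf.gather(gammas, labels)[:, None, None, :]
        beta = tf.gather(betas, labels)[:, None, None, :]
        mean, var = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True)
        return gamma * (x - mean) / tf.sqrt(var + 1e-5) + beta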

    opened by ghost 7
  • Poor results in WGAN mode on CelebA

    Hi,

    I'm trying to train on CelebA (cropped and resized to 64x64). The results in WGAN-GP mode look great, both in quality and diversity. However, when I set the mode to 'wgan', I get very distorted faces even after 200K iterations. Any ideas?

    [attached image: celeba_samples_102799]

    Thanks, Eitan
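
    For context, in 'wgan' mode this repo constrains the critic by weight clipping rather than a gradient penalty. A rough sketch of that step (the 0.01 threshold follows the WGAN paper; disc_params and session are illustrative names):

    # Clip every critic weight into [-0.01, 0.01] after each critic update.
    clip_ops = [w.assign(tf.clip_by_value(w, -0.01, 0.01)) for w in disc_params]
    # session.run(clip_ops)  # run once per critic iteration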

    opened by eitanrich 6
  • "python gan_64x64.py" met errors

    I've downloaded ImageNet small dataset (train_64x64.tar and valid_64x64.tar) and modified DATA_DIR in gan_64x64.py. I've also fixed a potential bug at line 116 (lib.concat -> tf.concat). But I still got the following error:

    Traceback (most recent call last):
      File "gan_64x64.py", line 477, in <module>
        fake_data = Generator(BATCH_SIZE/len(DEVICES))
      File "gan_64x64.py", line 210, in GoodGenerator
        output = ResidualBlock('Generator.Res3', 2*dim, 2*dim, 3, output, resample='up')
      File "gan_64x64.py", line 186, in ResidualBlock
        he_init=False, biases=True, inputs=inputs)
      File "gan_64x64.py", line 120, in UpsampleConv
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
      File "/data1/home/weixue/cv/gan/improved_wgan_training/tflib/ops/conv2d.py", line 111, in Conv2D
        data_format='NCHW'
      File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 396, in conv2d
        data_format=data_format, name=name)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
        op_def=op_def)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2329, in create_op
        set_shapes_for_outputs(ret)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1717, in set_shapes_for_outputs
        shapes = shape_func(op)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1667, in call_with_requiring
        return call_cpp_shape_fn(op, require_shape_fn=True)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
        debug_python_shape_fn, require_shape_fn)
      File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
        raise ValueError(err.message)
    ValueError: Dimensions must be equal, but are 256 and 128 for 'Generator.Res3.Shortcut/Conv2D' (op: 'Conv2D') with input shapes: [64,256,32,32], [1,1,128,128].
    

    It seems that the source code is still evolving. Is git "master" in a runnable state?

    opened by simonxue 6
  • How to interpret the losses?

    When I tried WGAN-GP on my own problems, I sometimes got very unbalanced losses (e.g. the discriminator loss is high, but the generator loss is around 0; see this). What does this mean? Does it mean the generator is too good?
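
    For reference, here is a sketch of how the wgan-gp costs are defined in these scripts (variable names follow the repo's style): the critic cost estimates the negative Wasserstein distance plus the penalty term, so the two losses are not directly comparable the way standard GAN losses are.

    gen_cost = -tf.reduce_mean(disc_fake)
    disc_cost = (tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
                 + LAMBDA * gradient_penalty)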

    opened by wchen342 5
  • Number of critic iterations

    I am working on a 2D case similar to your toy examples but with a more complex distribution. I noticed big improvements in the contours (i.e. the energy surface learned by the discriminator) when increasing the critic iterations from 5 to 50.

    I really think that 5 critic iterations is too low. I see you also use 5 iterations in the other examples, like CIFAR and MNIST, which does not show the full potential of the network. The critic should be given more time to converge.

    After only 400 generator iterations I am already getting better results than those reported in the paper for the Swiss Roll. [attached image]

    opened by stefdoerr 5
  • The results for LSUN bedrooms 128*128

    I found that the code here was used to train on LSUN bedrooms at 128*128 in your paper, but I cannot reproduce the results. The data link in your issue #30 is an "ILSVRC2012_128.tar" file, not the bedrooms images. I'm wondering what data pre-processing you applied to the original LSUN bedrooms dataset? Mine just uses a center crop of [center-64, center+64], as most papers do. By the way, did you use the whole LSUN bedrooms dataset for training? Could you please provide the images, or the detailed pre-processing method you used, for reproducing the bedroom results in your paper? Thanks so much!
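
    For clarity, here is a sketch of the center crop described above (my own illustration; the function name is hypothetical):

    def center_crop_128(img):
        # img: HxWxC array; keep the central 128x128 window.
        h, w = img.shape[:2]
        cy, cx = h // 2, w // 2
        return img[cy - 64:cy + 64, cx - 64:cx + 64]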

    opened by biuyq 4
  • Potential inconsistencies in calculation of gradient penalty between code and ArXiv paper

    I could be wrong, but it seems like the calculation of the gradient penalty is not the same across the different code examples in this repo. In the paper, I believe the calculation is shown in line 6 of Algorithm 1 (page 4 of the ArXiv paper) -- that line suggests the second of the two options below is correct. However, most code examples seem to use the first option.

    Option 1

    In gan_mnist.py (Lines 143-144), gan_64x64.py (495-496), gan_language.py (104-105), gan_cifar.py (130-131), and gan_cifar_resnet.py (260-261):

    differences = fake_data - real_data
    interpolates = real_data + (alpha*differences)
    
    # After rearranging, equivalent to: 
    # real_data + alpha*fake_data - alpha*real_data
    

    Option 2

    In gan_toy.py (Line 77) and ArXiv paper (Algorithm 1, line 6 on page 4):

    interpolates = alpha*real_data + ((1-alpha)*fake_data)
    
    # After rearranging, equivalent to: 
    # fake_data + alpha*real_data - alpha*fake_data
    

    real_data and fake_data seem to be swapped between the two options. Am I missing something?
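
    One quick observation (mine, not from this thread): since alpha is uniform on [0, 1], alpha and 1 - alpha are identically distributed, so both options draw interpolates from the same distribution. A numeric sanity check:

    import numpy as np

    rng = np.random.RandomState(0)
    real, fake = 2.0, 5.0
    alpha = rng.uniform(size=100000)
    opt1 = real + alpha * (fake - real)       # most code examples
    opt2 = alpha * real + (1 - alpha) * fake  # gan_toy.py / the paper
    # Both sets of interpolates are distributed the same way on [2, 5].
    print(opt1.mean(), opt2.mean())           # both approximately 3.5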

    opened by wronk 4
  • How to compute the second-order partial derivative in a non-graph-based framework

    I have noticed that this work is implemented in TensorFlow, where a graph of the gradient can be constructed. I wonder how to compute the second-order partial derivative in a non-graph-based deep learning framework like Torch/PyTorch/etc. It seems impossible to optimize the norm of the gradient with these frameworks.

    In any case, computing the gradient of the norm of the gradient involves the dot product of the Jacobian matrix and the gradient, so the computation may be expensive. How efficient is improved-wgan at computing this gradient?
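
    For reference, modern PyTorch does support differentiating through gradients via create_graph=True, so the penalty is trainable there as well. A minimal sketch (my own, assuming a critic mapping a [batch, dim] input to a [batch, 1] score):

    import torch

    def gradient_penalty(critic, interpolates):
        interpolates = interpolates.requires_grad_(True)
        scores = critic(interpolates)
        grads, = torch.autograd.grad(
            outputs=scores, inputs=interpolates,
            grad_outputs=torch.ones_like(scores),
            create_graph=True)  # keep the graph so the penalty is differentiable
        slopes = grads.view(grads.size(0), -1).norm(2, dim=1)
        return ((slopes - 1.0) ** 2).mean()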

    opened by zsffq999 4
  • A question about the structure of the ResNet

    Hi, thanks for your code. I have a question about the structure of the ResNet. I find that the residual block's output is shortcut + (0.3*output) instead of shortcut + output. Is there any theoretical basis for this, or is it an experimental conclusion? It is not the same as the original ResNet.

    Also, the code is easy to read, but there is one place I do not understand: gan_64x64.py line 530, _dev_disc_cost = session.run(disc_cost, feed_dict={all_real_data_conv: _data}). Should it be _dev_disc_cost = session.run(disc_cost, feed_dict={all_real_data_conv: images})? Thanks.

    opened by mathfinder 3
  • Query: WGAN-GP FID score (PyTorch)

    Thank you for sharing the implementations of the GAN-based models on popular datasets like CelebA. I have implemented the WGAN-GP model (in PyTorch), and the samples look close to the reported work (please refer to the attached image). But when I try to evaluate the Fréchet inception distance (FID score), I cannot comprehend the high values of 100+ (best value 113.4). Others have reported lower FID scores. The authors of Quality Aware Generative Adversarial Networks compared FID scores across various GAN models, and there Ishaan Gulrajani's official WGAN-GP implementation got an FID score of 12.89.

    I request you to guide me. Regards, Prabhav

    I have used the following repositories for reference while implementing the WGAN-GP model and evaluating the FID scores:

      • LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch
      • CharlesNord/WGAN-GP-DRAGAN-Celeba-Pytorch
      • eriklindernoren/PyTorch-GAN
      • hukkelas/pytorch-frechet-inception-distance
      • mseitzer/pytorch-fid
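
    For reference, the FID itself is the Fréchet distance between two Gaussians fitted to Inception activations of real and generated samples. A minimal sketch of the formula (my own illustration of what tools like pytorch-fid compute):

    import numpy as np
    from scipy import linalg

    def fid(mu1, sigma1, mu2, sigma2):
        # mu*, sigma*: mean and covariance of Inception activations.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real  # discard numerical imaginary parts
        diff = mu1 - mu2
        return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)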

    opened by KomputerMaster64 0
  • gan_mnist.py ERROR

    When I run python gan_mnist.py I get the errors below; can anyone solve the problem? Thank you very much.

    Traceback (most recent call last):
      File "gan_mnist.py", line 107, in <module>
        fake_data = Generator(BATCH_SIZE)
      File "gan_mnist.py", line 68, in Generator
        output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output)
      File "C:\Users\Tony\Downloads\improved_wgan_training\tflib\ops\deconv2d.py", line 102, in Deconv2D
        padding='SAME'
      File "D:\Anaconda3\envs\mytf\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
        raise e.with_traceback(filtered_tb) from None
      File "D:\Anaconda3\envs\mytf\lib\site-packages\tensorflow\python\util\dispatch.py", line 1170, in op_dispatch_handler
        result = api_dispatcher.Dispatch(args, kwargs)
    TypeError: Got an unexpected keyword argument 'value'

    opened by yihangzhao 1
  • If I intend to calculate the gradient penalty for two datasets of different dimensions, what should I do?

    My GAN produces two outputs with different distributions (for example, 20300->2010). Is there any way to calculate the gradient penalty for this part? Thanks.

    opened by HelloWorldLTY 0
  • Critic loss curve

    Hi,

    1. Should the critic loss curve, which should go to 0, include the gradient penalty or not?

    2. What should the behavior of the gradient penalty be (decreasing towards 0, or something else)?

    3. Will the result be the same if we backpropagate the gradient penalty individually or together with the discriminator loss, as below?

       (i) gradient_penalty.backward(retain_graph=True) [individual]
       (ii) loss_D = (- loss_real + loss_fake) + gradient_penalty; loss_D.backward() [with discriminator loss]

    opened by CBD88 0
  • How to run it?

    Hello, I'm a beginner. How do I configure the initial environment required to run improved_wgan_training? I want to run your code and learn from it. Can you provide a detailed installation and configuration tutorial? Thank you very much!

    opened by 524815200 0
Owner
Ishaan Gulrajani