Code for the paper "Which Training Methods for GANs do actually Converge?" (ICML 2018)

Overview

GAN stability

This repository contains the experiments in the supplementary material for the paper "Which Training Methods for GANs do actually Converge?".

To cite this work, please use

@INPROCEEDINGS{Mescheder2018ICML,
  author = {Lars Mescheder and Sebastian Nowozin and Andreas Geiger},
  title = {Which Training Methods for GANs do actually Converge?},
  booktitle = {International Conference on Machine Learning (ICML)},
  year = {2018}
}

You can find further details on our project page.

Usage

First download your data and put it into the ./data folder.

To train a new model, first create a config script similar to the ones provided in the ./configs folder. You can then train your model using

python train.py PATH_TO_CONFIG

To compute the inception score for your model and generate samples, use

python test.py PATH_TO_CONFIG

Finally, you can create nice latent space interpolations using

python interpolate.py PATH_TO_CONFIG

or

python interpolate_class.py PATH_TO_CONFIG

Pretrained models

We also provide several pretrained models.

You can use the models for sampling by entering

python test.py PATH_TO_CONFIG

where PATH_TO_CONFIG is one of the config files

configs/pretrained/celebA_pretrained.yaml
configs/pretrained/celebAHQ_pretrained.yaml
configs/pretrained/imagenet_pretrained.yaml
configs/pretrained/lsun_bedroom_pretrained.yaml
configs/pretrained/lsun_bridge_pretrained.yaml
configs/pretrained/lsun_church_pretrained.yaml
configs/pretrained/lsun_tower_pretrained.yaml

Our script will automatically download the model checkpoints and run the generation. You can find the outputs in the output/pretrained folders. Similarly, you can use the scripts interpolate.py and interpolate_class.py for generating interpolations for the pretrained models.

Please note that the config files *_pretrained.yaml are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model.
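
To sample from every pretrained model in one pass, the commands above can be assembled programmatically. The following is a minimal sketch, not part of the repository: `build_commands` and `run_all` are hypothetical helper names, and the script assumes it is started from the repository root, where test.py lives.

```python
import subprocess
import sys

# Config files listed in this README.
PRETRAINED_CONFIGS = [
    "configs/pretrained/celebA_pretrained.yaml",
    "configs/pretrained/celebAHQ_pretrained.yaml",
    "configs/pretrained/imagenet_pretrained.yaml",
    "configs/pretrained/lsun_bedroom_pretrained.yaml",
    "configs/pretrained/lsun_bridge_pretrained.yaml",
    "configs/pretrained/lsun_church_pretrained.yaml",
    "configs/pretrained/lsun_tower_pretrained.yaml",
]


def build_commands(script="test.py", configs=PRETRAINED_CONFIGS):
    """Return one argv list per config: [python, script, config]."""
    return [[sys.executable, script, cfg] for cfg in configs]


def run_all(script="test.py", dry_run=True):
    """Print each command; also execute it when dry_run is False."""
    for cmd in build_commands(script):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first failing run
            subprocess.run(cmd, check=True)
```

Calling `run_all()` only prints the seven commands; `run_all(dry_run=False)` executes them one after another. Passing `script="interpolate.py"` builds the corresponding interpolation commands instead.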

Notes

  • Batch normalization is currently not supported when using an exponential running average, as the running average is only computed over the parameters of the models and not the other buffers of the model.
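
The limitation can be illustrated with a toy example. The sketch below uses plain dictionaries as stand-ins for a model's parameters and buffers; the names (`ema_update`, `bn.running_mean`, and so on) and the values are illustrative assumptions, not the repository's code. The averaged copy tracks the trained parameters, while its batch-normalization statistics never move from their initial values.

```python
def ema_update(avg, current, beta=0.999):
    """In-place exponential moving average: avg <- beta*avg + (1-beta)*current."""
    for name, value in current.items():
        avg[name] = beta * avg[name] + (1.0 - beta) * value


# Toy state of the model being trained (hypothetical names and values).
train_params = {"conv.weight": 1.0, "bn.weight": 2.0}
train_buffers = {"bn.running_mean": 5.0, "bn.running_var": 3.0}

# Averaged copy used at test time, starting from its own initial state.
avg_params = {"conv.weight": 0.0, "bn.weight": 0.0}
avg_buffers = {"bn.running_mean": 0.0, "bn.running_var": 1.0}

for _ in range(10000):
    # Only the parameters are averaged; the buffers are never touched.
    ema_update(avg_params, train_params)

print(avg_params)   # close to train_params
print(avg_buffers)  # still the initial statistics, unlike train_buffers
```

The averaged parameters therefore end up paired with normalization statistics that do not match them, which is one way to read the note above.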

Results

celebA-HQ

(image: celebA-HQ)

Imagenet

(images: Imagenet 0, Imagenet 1, Imagenet 2, Imagenet 3, Imagenet 4)

Comments
  • Instance noise used in all experiments?

    I noticed in the code for inputs.py that instance noise appears to be applied universally to all training examples (along with other data augmentation). Was this, and the other data augmentation, used for all experimental results in the paper? I don't remember it being mentioned. I was surprised to see this, as instance noise was one of the regularization approaches you were comparing against, e.g. in the R1 comparisons; I didn't realize both were being used simultaneously. Or is the scale of the noise injected here (uniform on [0, 1/128)) much lower than required for use as a regularizer?

    opened by zplizzi 4
  • Minimal Dirac-GAN

    Hello @LMescheder, thanks for the nice repo. I'd like to create a minimal Dirac-GAN example. Ideally, I'd want a side-by-side comparison with and without the proposed regularization. This is what I have so far:

    import torch as th
    from torch import nn
    import numpy as np
    from tqdm import tqdm
    import os
    from matplotlib import pyplot as plt
    
    
    class Generator(nn.Module):
      def __init__(self, z_dim=1, gen_dim=1):
        super(Generator, self).__init__()
        self.z_dim = 1
        self.h1 = nn.Linear(z_dim, gen_dim, bias=False)
    
      def forward(self, x):
        x = self.h1(x)
        return x
    
    
    class Discriminator(nn.Module):
      def __init__(self, gen_dim=1):
        super(Discriminator, self).__init__()
        self.h1 = nn.Linear(gen_dim, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    
      def forward(self, x):
        x = self.sigmoid(self.h1(x))
        return x
    
    
    def binary_crossentropy(y_pred, y_true):
      y_pred = th.clamp(y_pred, 1e-7, 1. - 1e-7)
      bce = -y_true * th.log(y_pred) - (1. - y_true) * th.log(1. - y_pred)
      return th.mean(bce)
    
    
    def generator_real(batch_size):
      while True:
        yield th.randn(batch_size, 1, requires_grad=True, dtype=th.float32)
    
    
    def plot_figure(theta, iteration):
      delta = 1. / 300
      x = np.arange(-1., 1., delta)
      y = np.zeros_like(x)
      y[int(theta * 300) + 300] = 1
      plt.cla()
      plt.plot(x, y)
      plt.yticks([])
      plt.xlabel(r'$\theta$')
      plt.grid(False)
      # plt.savefig('/tmp/gif/tmp_%08d.jpg' % (iteration))
      plt.pause(0.01)
    
    
    def main():
      if not os.path.exists('/tmp/gif'):
        os.makedirs('/tmp/gif')
      use_reg = False
      gamma = 1.0
      real_label = 1
      fake_label = 0
      epochs = 5
      steps_per_epoch = 100
      batch_size = 128
      generator = Generator()
      # set the one generator weight
      generator.h1.weight = th.nn.Parameter(th.tensor([[0.7]], dtype=th.float32))
      discriminator = Discriminator()
      # set the one discriminator weight
      discriminator.h1.weight = th.nn.Parameter(th.tensor([[0.7]], dtype=th.float32))
      real_data_gen = generator_real(batch_size)
    
      g_optimizer = th.optim.SGD(generator.parameters(), lr=0.1, momentum=0.0)
      d_optimizer = th.optim.SGD(discriminator.parameters(), lr=0.1, momentum=0.0)
    
      for epoch in range(1, epochs + 1):
        for step in tqdm(range(steps_per_epoch)):
          # Save figure
          plot_figure(discriminator.h1.weight.item(), iteration=epoch * steps_per_epoch + step)
          # 1.) Update the discriminator
    
          # a.) real data
          d_optimizer.zero_grad()
          X_real = next(real_data_gen)
          y_pred = discriminator(X_real)
          real_loss = binary_crossentropy(y_pred, real_label * th.ones_like(y_pred))
          if use_reg:
            # keep the penalty a tensor; calling .item() would cut it off from autograd
            real_loss = real_loss + 0.5 * gamma * discriminator.h1.weight.pow(2).sum()
          real_loss.backward()
    
          # b.) fake data
          noise = th.randn(batch_size, 1)
          X_fake = generator(noise)
          y_pred = discriminator(X_fake.detach())
          fake_loss = binary_crossentropy(y_pred, fake_label * th.ones_like(y_pred))
          fake_loss.backward()
          d_optimizer.step()
    
          # 2.) Update the generator
    
          g_optimizer.zero_grad()
          y_pred = discriminator(X_fake)
          # use real label in the loss
          fake_loss_if_real = binary_crossentropy(y_pred, real_label * th.ones_like(y_pred))
          fake_loss_if_real.backward()
          g_optimizer.step()
    
    
    if __name__ == '__main__':
      main()
    

    Output: (image: no_reg)

    Your animations looked much better (i.e., more oscillation without regularization). Is this due to a bug (I think so) or to the hyperparameters?

    opened by see-- 3
  • 2D-data distributions possible to recreate with code?

    Hello,

    First, thank you very much for your paper, it was very interesting and it helped me a lot! In your paper, in figure 8 and figure 9 in the appendices, you do some testing with 2D data distributions. Is it possible to reproduce these distributions with the code?

    Thank you in advance!

    opened by SimonVerlinden 0
  • Question about Eq. 1 in the article.

    Hi, I've recently read your paper and I think it is great. I have a question about Eq. 1 in the article. I found that the equation does not reduce to the original GAN objective when f is replaced with -log(1+exp(-t)). I think the negation of the discriminator output is switched between the two terms on the right-hand side. Am I correct?

    opened by yoon28 0
  • Why use different architectures for ImageNet, celeba and church?

    Hi, thank you for sharing the code. It seems that you used a large model for ImageNet and a relatively small one for the other datasets, e.g. celebA and church. Would applying the more complex model (used for ImageNet) to celebA produce much better results? I am very curious about how you selected the model for each dataset; could you share your reasoning? Thank you very much.

    opened by MiaoyunZhao 0
  • Does it only support the standard GAN loss and WGAN-GP?

    Hi, I'm using the gradient penalty with the LSGAN loss, but it did not work, so I replaced LSGAN with the standard GAN loss and it works fine.

    Apart from the WGAN loss, which types of GAN loss does it support?

    RaGAN?

    opened by Johnson-yue 0
  • Questions with batch normalization

    Hello, I just added BN to the resblock and ran 'python train.py configs/celebA-HQ'; training then fails to converge at all (see the picture below).

    (image: 00006000)

    I noticed the moving average in your training, but I think its influence on BN is quite small. I cannot understand why training fails to converge after adding BN. Looking forward to your reply!

    opened by LuChengTHU 1
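
The Dirac-GAN question raised in the comments above can also be explored without any neural-network code. The following is a minimal sketch under stated assumptions, not the authors' implementation: it takes the Dirac-GAN to have real data concentrated at 0, a generator that outputs the single point theta, and a linear discriminator D(x) = psi * x; it uses f(t) = -log(1 + exp(-t)) and simultaneous gradient steps. For this linear discriminator the R1 gradient penalty reduces to (gamma / 2) * psi^2. The function names, starting point, and step size are choices made for illustration.

```python
import math


def f_prime(t):
    """Derivative of f(t) = -log(1 + exp(-t)), i.e. 1 / (1 + exp(t)), computed stably."""
    if t >= 0:
        e = math.exp(-t)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(t))


def simulate_dirac_gan(theta=1.0, psi=0.0, gamma=0.0, lr=0.1, steps=1000):
    """Simultaneous gradient steps on L(theta, psi) = f(theta * psi) + f(0).

    The generator (theta) minimizes L; the discriminator (psi) maximizes
    L - (gamma / 2) * psi**2. Returns the list of (theta, psi) iterates.
    """
    path = [(theta, psi)]
    for _ in range(steps):
        a = f_prime(theta * psi)
        d_theta = -a * psi               # generator update direction
        d_psi = a * theta - gamma * psi  # discriminator update direction with R1 term
        theta, psi = theta + lr * d_theta, psi + lr * d_psi
        path.append((theta, psi))
    return path


def distance_to_equilibrium(point):
    """Euclidean distance from the equilibrium (theta, psi) = (0, 0)."""
    return math.hypot(point[0], point[1])


unregularized = simulate_dirac_gan(gamma=0.0)
regularized = simulate_dirac_gan(gamma=1.0)
print("no regularization:", distance_to_equilibrium(unregularized[-1]))
print("gamma = 1.0:      ", distance_to_equilibrium(regularized[-1]))
```

With these settings the unregularized iterates move away from (0, 0) on every step, while the regularized run ends close to it. Plotting the returned paths in the (theta, psi) plane gives a side-by-side comparison.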
Owner
Lars Mescheder