Code for paper "Which Training Methods for GANs do actually Converge? (ICML 2018)"

GAN stability

This repository contains the experiments in the supplementary material for the paper "Which Training Methods for GANs do actually Converge?".

To cite this work, please use

@INPROCEEDINGS{Mescheder2018ICML,
  author = {Lars Mescheder and Sebastian Nowozin and Andreas Geiger},
  title = {Which Training Methods for GANs do actually Converge?},
  booktitle = {International Conference on Machine Learning (ICML)},
  year = {2018}
}

You can find further details on our project page.

Usage

First download your data and put it into the ./data folder.

To train a new model, first create a config file similar to the ones provided in the ./configs folder. You can then train your model using

python train.py PATH_TO_CONFIG

To compute the inception score for your model and generate samples, use

python test.py PATH_TO_CONFIG

Finally, you can create nice latent space interpolations using

python interpolate.py PATH_TO_CONFIG

or

python interpolate_class.py PATH_TO_CONFIG

Pretrained models

We also provide several pretrained models.

You can use the models for sampling by entering

python test.py PATH_TO_CONFIG

where PATH_TO_CONFIG is one of the config files

configs/pretrained/celebA_pretrained.yaml
configs/pretrained/celebAHQ_pretrained.yaml
configs/pretrained/imagenet_pretrained.yaml
configs/pretrained/lsun_bedroom_pretrained.yaml
configs/pretrained/lsun_bridge_pretrained.yaml
configs/pretrained/lsun_church_pretrained.yaml
configs/pretrained/lsun_tower_pretrained.yaml

Our script will automatically download the model checkpoints and run the generation. You can find the outputs in the output/pretrained folders. Similarly, you can use the scripts interpolate.py and interpolate_class.py for generating interpolations for the pretrained models.
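To sample from every pretrained model in one go, the commands above can be wrapped in a small driver. The following is only a sketch: `build_command` and `run_all` are hypothetical helper names that are not part of this repository, and the script assumes it is started from the repository root. By default it only prints the commands; pass `dry_run=False` to actually execute them.

```python
import subprocess
import sys

# Config paths exactly as listed in this README.
PRETRAINED_CONFIGS = [
    "configs/pretrained/celebA_pretrained.yaml",
    "configs/pretrained/celebAHQ_pretrained.yaml",
    "configs/pretrained/imagenet_pretrained.yaml",
    "configs/pretrained/lsun_bedroom_pretrained.yaml",
    "configs/pretrained/lsun_bridge_pretrained.yaml",
    "configs/pretrained/lsun_church_pretrained.yaml",
    "configs/pretrained/lsun_tower_pretrained.yaml",
]


def build_command(script, config_path):
    """Return the argv list for `python SCRIPT PATH_TO_CONFIG` (hypothetical helper)."""
    return [sys.executable, script, config_path]


def run_all(script="test.py", configs=PRETRAINED_CONFIGS, dry_run=True):
    """Print one command per config and, unless dry_run is set, execute it."""
    commands = [build_command(script, cfg) for cfg in configs]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first failing run
            subprocess.run(cmd, check=True)
    return commands


# Dry run: only shows what would be executed.
commands = run_all(script="test.py", dry_run=True)
```

Replacing `script="test.py"` with `"interpolate.py"` or `"interpolate_class.py"` builds the corresponding interpolation commands in the same way.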

Please note that the config files *_pretrained.yaml are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model.

Notes

  • Batch normalization is currently not supported when using an exponential running average, as the running average is only computed over the parameters of the models and not the other buffers of the model.
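The limitation in the note above can be illustrated with a small PyTorch sketch. It is not the repository's implementation: `ema_update` is a hypothetical helper that averages only what `named_parameters()` returns, and the two-layer model is a toy chosen for illustration. Batch normalization keeps its running statistics in buffers, so the averaged copy never receives them.

```python
import copy

import torch
from torch import nn


def ema_update(model_tgt, model_src, beta=0.999):
    """Blend the parameters of model_src into model_tgt; buffers are not touched."""
    params_src = dict(model_src.named_parameters())
    with torch.no_grad():
        for name, p_tgt in model_tgt.named_parameters():
            # p_tgt <- beta * p_tgt + (1 - beta) * p_src
            p_tgt.mul_(beta).add_(params_src[name], alpha=1.0 - beta)


# Toy model with one batch-norm layer (illustrative only).
model = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
nn.init.constant_(model[0].weight, 1.0)
nn.init.constant_(model[0].bias, 0.5)
model_avg = copy.deepcopy(model)

# A forward pass in training mode updates the running statistics of `model`.
# These statistics are buffers, not parameters.
model.train()
model(torch.ones(4, 2))

ema_update(model_avg, model, beta=0.5)

print("buffers:", [name for name, _ in model.named_buffers()])
print("running_mean (trained): ", model[1].running_mean)
print("running_mean (averaged):", model_avg[1].running_mean)
```

In this sketch the averaged copy keeps its initial running mean, while the trained model's running mean has moved. That mismatch is why evaluating the averaged model with batch normalization is unreliable unless the buffers are copied or averaged as well.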

Results

celebA-HQ

[image: celebA-HQ]

Imagenet

[images: Imagenet 0, Imagenet 1, Imagenet 2, Imagenet 3, Imagenet 4]

Comments
  • Instance noise used in all experiments?

    I noticed in the code for inputs.py that instance noise appears to be applied universally to all training examples (along with other data augmentation). Was this and the other data augmentation used for all experimental results in the paper? I don't remember it being mentioned. I was surprised to see this, as instance noise was one of the regularization approaches you were comparing against, e.g., the R1 approach; I didn't realize both were being used simultaneously. Or is the scale of the noise injected here (uniform on [0, 1/128)) much lower than what is required for use as a regularizer?

    opened by zplizzi 4
  • Minimal Dirac-GAN

    Hello @LMescheder , thanks for the nice repo. I'd like to create a minimal Dirac-GAN example. Ideally, I'd want a side-by-side comparison with and without the proposed regularization. This is what I have so far:

    import torch as th
    from torch import nn
    import numpy as np
    from tqdm import tqdm
    import os
    from matplotlib import pyplot as plt
    
    
    class Generator(nn.Module):
      def __init__(self, z_dim=1, gen_dim=1):
        super(Generator, self).__init__()
        self.z_dim = z_dim  # store the constructor argument instead of a hard-coded 1
        self.h1 = nn.Linear(z_dim, gen_dim, bias=False)
    
      def forward(self, x):
        x = self.h1(x)
        return x
    
    
    class Discriminator(nn.Module):
      def __init__(self, gen_dim=1):
        super(Discriminator, self).__init__()
        self.h1 = nn.Linear(gen_dim, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    
      def forward(self, x):
        x = self.sigmoid(self.h1(x))
        return x
    
    
    def binary_crossentropy(y_pred, y_true):
      y_pred = th.clamp(y_pred, 1e-7, 1. - 1e-7)
      bce = -y_true * th.log(y_pred) - (1. - y_true) * th.log(1. - y_pred)
      return th.mean(bce)
    
    
    def generator_real(batch_size):
      while True:
        yield th.randn(batch_size, 1, requires_grad=True, dtype=th.float32)
    
    
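    def plot_figure(theta, iteration):
      delta = 1. / 300
      x = np.arange(-1., 1., delta)
      y = np.zeros_like(x)
      # clip the index so that values of theta outside [-1, 1) do not raise an IndexError
      idx = int(np.clip(int(theta * 300) + 300, 0, len(x) - 1))
      y[idx] = 1
      plt.cla()
      plt.plot(x, y)
      plt.yticks([])
      plt.xlabel(r'$\theta$')
      plt.grid(False)  # pass a boolean; the string 'off' is not a reliable way to disable the grid
      # plt.savefig('/tmp/gif/tmp_%08d.jpg' % (iteration))
      plt.pause(0.01)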
    
    
    def main():
      if not os.path.exists('/tmp/gif'):
        os.makedirs('/tmp/gif')
      use_reg = False
      gamma = 1.0
      real_label = 1
      fake_label = 0
      epochs = 5
      steps_per_epoch = 100
      batch_size = 128
      generator = Generator()
      # set the one generator weight
      generator.h1.weight = th.nn.Parameter(th.tensor([[0.7]], dtype=th.float32))
      discriminator = Discriminator()
      # set the one discriminator weight
      discriminator.h1.weight = th.nn.Parameter(th.tensor([[0.7]], dtype=th.float32))
      real_data_gen = generator_real(batch_size)
    
      g_optimizer = th.optim.SGD(generator.parameters(), lr=0.1, momentum=0.0)
      d_optimizer = th.optim.SGD(discriminator.parameters(), lr=0.1, momentum=0.0)
    
      for epoch in range(1, epochs + 1):
        for step in tqdm(range(steps_per_epoch)):
          # Save figure (theta is the generator parameter, so plot the generator weight)
          plot_figure(generator.h1.weight.item(), iteration=epoch * steps_per_epoch + step)
          # 1.) Update the discriminator
    
          # a.) real data
          d_optimizer.zero_grad()
          X_real = next(real_data_gen)
          y_pred = discriminator(X_real)
          real_loss = binary_crossentropy(y_pred, real_label * th.ones_like(y_pred))
          if use_reg:
            # keep the penalty as a tensor; .item() would detach it from the graph
            real_loss = real_loss + 0.5 * gamma * discriminator.h1.weight.pow(2).sum()
          real_loss.backward()
    
          # b.) fake data
          noise = th.randn(batch_size, 1)
          X_fake = generator(noise)
          y_pred = discriminator(X_fake.detach())
          fake_loss = binary_crossentropy(y_pred, fake_label * th.ones_like(y_pred))
          fake_loss.backward()
          d_optimizer.step()
    
          # 2.) Update the generator
    
          g_optimizer.zero_grad()
          y_pred = discriminator(X_fake)
          # use real label in the loss
          fake_loss_if_real = binary_crossentropy(y_pred, real_label * th.ones_like(y_pred))
          fake_loss_if_real.backward()
          g_optimizer.step()
    
    
    if __name__ == '__main__':
      main()
    

    Output: [image: no_reg]

    Your animations looked much better (i.e., more oscillation without regularization). Is this due to a bug (I think so) or to the hyperparameters?

    opened by see-- 3
  • 2D-data distributions possible to recreate with code?

    Hello,

    First, thank you very much for your paper, it was very interesting and it helped me a lot! In your paper, in figure 8 and figure 9 in the appendices, you do some testing with 2D data distributions. Is it possible to reproduce these distributions with the code?

    Thank you in advance!

    opened by SimonVerlinden 0
  • Question about Eq. 1 in the article.

    Hi, I've recently read your paper and I think it is great. I have a question about Eq. 1 in the article. I found that substituting f(t) = -log(1+exp(-t)) into the equation does not recover the original GAN objective. I think the signs of the discriminator output in the two terms on the right-hand side are swapped. Am I correct?

    opened by yoon28 0
  • Why use different architectures for ImageNet, celeba and church?

    Hi, thank you for sharing the code. It seems that you used a large model for ImageNet and a relatively small one for the other datasets, e.g. celeba and church. Would applying the more complex model (used for ImageNet) to celeba produce much better results? I am very curious about how you selected the model for each dataset. Can you share your reasoning with me? Thank you very much.

    opened by MiaoyunZhao 0
  • Does it only support the standard GAN loss and WGAN-GP?

    Hi, I'm using the gradient penalty with the LSGAN loss, but it did not work. When I replaced LSGAN with the standard GAN loss, it worked fine.

    Besides the WGAN loss, which types of GAN loss does it support?

    RaGAN?

    opened by Johnson-yue 0
  • Questions with batch normalization

    Hello, I just added BN to the resblock and ran 'python train.py configs/celebA-HQ', and training does not converge at all (see the picture below).

    [image: 00006000]

    I noticed the moving average in your training, but I think its influence on BN is quite small. I cannot understand why training fails to converge after adding BN. Looking forward to your reply!

    opened by LuChengTHU 1
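For the minimal Dirac-GAN question above, a side-by-side comparison does not need neural network layers. The sketch below assumes the setup in which the generator is a point mass at theta, the discriminator is the linear function D(x) = psi * x, the real data is a point mass at 0, and f(t) = -log(1+exp(-t)). It takes simultaneous gradient steps with and without the penalty 0.5 * gamma * psi**2, which is the regularizer used in the posted code. The function names, step size, and gamma are illustrative choices, not values taken from this repository.

```python
import math


def f_prime(t):
    """Derivative of f(t) = -log(1 + exp(-t)), i.e. sigmoid(-t), computed stably."""
    if t >= 0:
        e = math.exp(-t)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(t))


def simulate_dirac_gan(theta=1.0, psi=1.0, gamma=0.0, lr=0.1, steps=2000):
    """Simultaneous gradient steps on the objective f(theta * psi) + f(0).

    The generator (theta) descends the objective; the discriminator (psi)
    ascends it minus the penalty 0.5 * gamma * psi**2.
    Returns the list of (theta, psi) iterates.
    """
    path = [(theta, psi)]
    for _ in range(steps):
        g = f_prime(theta * psi)
        d_theta = -g * psi
        d_psi = g * theta - gamma * psi
        # both players are updated from the same old values
        theta, psi = theta + lr * d_theta, psi + lr * d_psi
        path.append((theta, psi))
    return path


def norm(point):
    """Distance of a (theta, psi) pair from the equilibrium (0, 0)."""
    return math.hypot(point[0], point[1])


unregularized = simulate_dirac_gan(gamma=0.0)
regularized = simulate_dirac_gan(gamma=1.0)

print("start:           %.4f" % norm(unregularized[0]))
print("no regularizer:  %.4f" % norm(unregularized[-1]))
print("with gamma=1:    %.4f" % norm(regularized[-1]))
```

Without the penalty, each update is orthogonal to the current (theta, psi), so the distance from the equilibrium can only grow and the iterates keep circling. With gamma = 1 the psi update gains a damping term and the iterates move toward (0, 0). Plotting the returned paths, for example with matplotlib, shows the two behaviors side by side.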
Owner
Lars Mescheder