A stable algorithm for GAN training

Overview

DRAGAN (Deep Regret Analytic Generative Adversarial Networks)

Link to our paper - https://arxiv.org/abs/1705.07215

PyTorch implementation (thanks!) - https://github.com/jfsantos/dragan-pytorch

Procedure (to use our algorithm):

  1. Pick your favorite architecture and objective function for the game.
  2. Tune the hyperparameter 'c', which decides the size of the local regions. Our intuition is that small values extract better performance from a given architecture due to relaxed restrictions, while slightly larger values give more stability. Set it carefully, taking your domain range into account and making sure that the perturbations do not lie on the data manifold.
  3. Tune lambda if necessary; it has the usual meaning of regularization intensity. Set 'k' to 1. (A sketch of the resulting penalty term follows this list.)
  4. If your results are still bad, go back to Step 1 and try a different architecture and objective.
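
For reference, here is a minimal PyTorch sketch of the penalty term. It is an illustration under the perturbation scheme used in this repository's code (see the issues below for how that differs from the paper), not a definitive implementation; `discriminator` is assumed to return raw logits:

    import torch

    def dragan_penalty(discriminator, X, c=0.5, lam=10.0, k=1.0):
        # Perturb real samples with noise scaled by the batch standard
        # deviation (code-style perturbation; the paper uses Gaussian noise).
        X_p = X + c * X.std() * torch.rand_like(X)
        # Interpolate randomly between real and perturbed samples.
        alpha = torch.rand(X.size(0), *([1] * (X.dim() - 1)), device=X.device)
        interpolates = (X + alpha * (X_p - X)).detach().requires_grad_(True)
        d_out = discriminator(interpolates)
        # Per-sample input gradients of the discriminator output.
        grads = torch.autograd.grad(d_out.sum(), interpolates, create_graph=True)[0]
        # Push the per-sample gradient norm towards k.
        slopes = grads.flatten(start_dim=1).norm(2, dim=1)
        return lam * ((slopes - k) ** 2).mean()

The returned term is added to the discriminator's loss; lambda and k correspond to Steps 2-3 above.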

Interesting discussion with Ian Goodfellow and Martin Arjovsky on why GANs are unstable and where improvements come from

https://www.facebook.com/kodali.naveen.90/posts/1047257878740881

An interesting new paper by Fedus et al. ("Many Paths to Equilibrium") came out following this:

https://arxiv.org/abs/1710.08446

Some repositories that helped us and that you may find helpful (big thanks!):

https://github.com/igul222/improved_wgan_training

https://github.com/wiseodd/generative-models/tree/master/GAN

https://github.com/openai/improved-gan/tree/master/inception_score

Comments
  • About gradient penalty

    Hi, I see that in your paper the discriminator is updated with:

    [image: discriminator update equation from the paper]

    I think the D that appears here means a "discriminator with sigmoid". But in the implementation, we always use "tf.nn.sigmoid_cross_entropy_with_logits" instead of calculating the sigmoid function explicitly.

    In your paper, I think the gradient penalty is based on the "discriminator with sigmoid". But in your code, gradients = tf.gradients(discriminator(interpolates), [interpolates])[0] is the gradient term w.r.t. the "discriminator without sigmoid" (i.e., the raw logits).
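
    For clarity, a small runnable PyTorch sketch of the two quantities being contrasted (the toy discriminator here is purely illustrative):

    import torch

    D = torch.nn.Linear(4, 1)                     # toy logit "discriminator"
    x = torch.randn(8, 4).requires_grad_(True)

    # Gradient of the raw logits w.r.t. the input (what the code penalizes).
    g_logits = torch.autograd.grad(D(x).sum(), x)[0]

    # Gradient of sigmoid(logits) (a "discriminator with sigmoid"); it differs
    # from the above by the factor s * (1 - s), where s = sigmoid(logit).
    x2 = x.detach().requires_grad_(True)
    g_sigmoid = torch.autograd.grad(torch.sigmoid(D(x2)).sum(), x2)[0]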

    opened by Aixile 2
  • Difference between paper and code in obtaining perturbed inputs

    In your paper (page 6, equation 1), you calculate the gradient penalty in this way:

    [image: gradient penalty equation from the paper]

    So it looks like you simply add pixelwise Gaussian noise with mean 0 and standard deviation c to the inputs of the discriminator. Furthermore, you suggest using c = 10.

    However, in your code, you calculate the perturbed inputs in this way:

    def get_perturbed_batch(minibatch):
        return minibatch + 0.5 * minibatch.std() * np.random.random(minibatch.shape)
    

    So here you add pixelwise uniform noise from the interval [0.0, 0.5 * minibatch.std()), which is very different from the formulation of the noise in the paper. Furthermore, your code still contains the interpolation issue from https://github.com/kodalinaveen3/DRAGAN/issues/5

    Of course, you wrote "We use small pixel-level noise but it is possible to find better ways of imposing this penalty. However, this exploration is beyond the scope of our paper." But maybe there are important reasons why you generate X_p this way in the code?
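
    For comparison, a minimal NumPy sketch of the two perturbation schemes being discussed (the function names are illustrative; c is the standard deviation from the paper's formulation):

    import numpy as np

    def perturb_paper(minibatch, c):
        """Paper (page 6, eq. 1): pixelwise Gaussian noise, mean 0, std c."""
        return minibatch + c * np.random.randn(*minibatch.shape)

    def perturb_code(minibatch):
        """Repository code: uniform noise in [0, 0.5 * batch std)."""
        return minibatch + 0.5 * minibatch.std() * np.random.random(minibatch.shape)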

    opened by Netzeband 1
  • Why is interpolation necessary?

    interpolates = X + (alpha * differences)
                 = X + U[0,1] * (X_p - X)
                 = X + U[0,1] * (X + 0.5 * std * U[0,1] - X)
                 = X + U[0,1] * 0.5 * std * U[0,1]
                 = X + 0.5 * std * U[0,1]^2
                ~= X_p
    

    So it appears that X_p can be used directly instead of interpolates.

    And the paper actually uses X_p instead of interpolates.

    opened by shaform 1
  • Is the gradient penalty loss problematic when input image is large?

    Hi, here the sum of squares is computed as slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

    However, when the input image is extremely large, the dimension of gradients would be huge. It seems possible that the resulting slopes would be extremely large compared to 1. Is this the case?
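
    A quick NumPy illustration of the concern (the 0.01 entry scale is hypothetical): for i.i.d. gradient entries of fixed scale, the L2 norm grows roughly like sqrt(d) with the input dimension d, so the same per-pixel slopes yield much larger norms on larger images.

    import numpy as np

    rng = np.random.default_rng(0)
    for d in (1_000, 100_000, 10_000_000):
        g = 0.01 * rng.standard_normal(d)   # stand-in gradient with i.i.d. entries
        print(d, np.linalg.norm(g))         # grows roughly like 0.01 * sqrt(d)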

    opened by shaform 1
  • NaN issue

    When I try to apply this loss function to my model (modified from DCGAN), I get a NaN loss. I wonder whether this modified vanilla loss function can lead to this issue? Thanks!
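
    One frequent cause of NaNs in gradient-penalty code (an assumption here, not a confirmed diagnosis of this issue) is that tf.sqrt has an unbounded gradient at zero; a common workaround is adding a small epsilon inside the square root, e.g. as a variant of the slopes line quoted in the issue above:

    # Hypothetical workaround: keep the argument of sqrt strictly positive.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]) + 1e-10)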

    opened by yuanpengX 1
  • If you're trying to implement DRAGAN, see some of the closed issues - the code here is old and is not a final implementation of the paper

    I'm not sure if this is abandoned or not, but I spent a considerable amount of time trying to replicate the notebook without looking at the closed issues.

    Teaches me a lesson to always check the closed issues first!

    See the following issues (which have been answered by the authors): https://github.com/kodalinaveen3/DRAGAN/issues/6 and https://github.com/kodalinaveen3/DRAGAN/issues/5

    It also looks like several implementations are based on the old version of the code, e.g. https://github.com/opendp/smartnoise-sdk/blob/main/synth/snsynth/pytorch/nn/patectgan.py

    opened by ssabdb 0
  • Update DRAGAN.ipynb

    One argument of tf.nn.sigmoid_cross_entropy_with_logits() changed from "targets" to "labels" in recently released versions of TensorFlow. Refer to https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
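
    For reference, assuming placeholder tensors y (labels) and d_logits, the updated call looks like:

    # Old (pre-1.0 TensorFlow) argument name:
    # loss = tf.nn.sigmoid_cross_entropy_with_logits(targets=y, logits=d_logits)
    # Current releases use labels= instead:
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=d_logits)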

    opened by hoshino042 0
  • Is this right?

    Hi, regarding his DRAGAN code for TensorFlow:

    In his code, you said that alpha should have a value between -1 and 1, right? So is alpha between 0 and 1 a bug? However, most of the code I have seen draws alpha between 0 and 1.

    Which is right?

    opened by taki0112 0