GAN JAX - A toy project to generate images from GANs with JAX


This project aims to bring the power of JAX, a Python framework developed by Google and DeepMind, to the training of Generative Adversarial Networks for image generation.

JAX

JAX logo

JAX is a framework developed by DeepMind (Google) that lets you build machine learning models in a more powerful (XLA compilation) and flexible way than its counterpart TensorFlow, using an API almost entirely based on NumPy's ndarray (but stored on the GPU, or TPU if available). It also provides new utilities for gradient computation (per-sample gradients, Jacobians with backward and forward propagation, Hessians, ...), a better seed system (for reproducibility) and a tool to batch complicated operations automatically and efficiently (vmap).
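
As a quick illustration of these features (a minimal sketch, not code from this repository):

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

grad_fn = jax.jit(jax.grad(loss))                        # XLA-compiled gradient w.r.t. w
per_sample_grads = jax.vmap(grad_fn, in_axes=(None, 0))  # automatic batching over x

key = jax.random.PRNGKey(42)                             # explicit, reproducible seed
key, w_key, x_key = jax.random.split(key, 3)
w = jax.random.normal(w_key, (3, 2))
x = jax.random.normal(x_key, (8, 3))                     # batch of 8 samples
print(per_sample_grads(w, x).shape)                      # (8, 3, 2): one gradient per sample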

GitHub link: https://github.com/google/jax

GAN

GAN diagram

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, pitting one against the other (hence "adversarial"), in order to generate new, synthetic instances of data that can pass for real data. They are widely used in image, video and voice generation. GANs were introduced in a 2014 paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio. Referring to GANs, Facebook's AI research director Yann LeCun called adversarial training "the most interesting idea in the last 10 years in ML". (source)

Original paper: https://arxiv.org/abs/1406.2661
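
For intuition, here is a rough sketch of that adversarial objective in JAX, with toy linear stand-ins for the two networks (illustrative only, not this repository's models):

import jax
import jax.numpy as jnp

def generator(g_params, z):
    return z @ g_params              # (batch, latent) -> (batch, data)

def discriminator(d_params, x):
    return x @ d_params              # (batch, data) -> (batch,) logits

def disc_loss(d_params, g_params, real, z):
    # The discriminator maximizes log D(x) + log(1 - D(G(z))); we minimize the negation.
    fake = generator(g_params, z)
    return -jnp.mean(jax.nn.log_sigmoid(discriminator(d_params, real))
                     + jax.nn.log_sigmoid(-discriminator(d_params, fake)))

def gen_loss(g_params, d_params, z):
    # Non-saturating generator loss: maximize log D(G(z)).
    return -jnp.mean(jax.nn.log_sigmoid(discriminator(d_params, generator(g_params, z))))

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
g_params = jax.random.normal(k1, (16, 32))   # latent dim 16 -> data dim 32
d_params = jax.random.normal(k2, (32,))
real = jax.random.normal(k3, (8, 32))
z = jax.random.normal(k4, (8, 16))
print(disc_loss(d_params, g_params, real, z), gen_loss(g_params, d_params, z))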

Several ideas have improved the training of GANs over the years. For example:

Deep Convolution GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196

The goal of this project is to implement these ideas in the JAX framework.

Installation

You can install JAX by following the instructions at JAX - Installation.

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In this case you can install JAX using the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Then you can install TensorFlow to benefit from tf.data.Dataset to handle the data and from the pre-installed datasets. However, TensorFlow allocates GPU memory as soon as it is used, which conflicts with JAX computations on the GPU. Therefore, you should install the CPU-only version of TensorFlow. Visit TensorFlow - Installation with pip to get the CPU-only version of TensorFlow 2 for your OS and Python version.

Example with Linux and Python 3.9:

pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl
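
Alternatively, if a GPU build of TensorFlow is already installed, you can hide the GPUs from TensorFlow at runtime with the standard TensorFlow API (run this before any TensorFlow operation):

import tensorflow as tf

# Hide all GPUs from TensorFlow so it does not grab the GPU memory JAX needs.
tf.config.set_visible_devices([], 'GPU')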

Then you can install the other libraries from requirements.txt. This will install Haiku and Optax, two useful add-on libraries for implementing and optimizing machine learning models with JAX (see the sketch after the command below).

pip install -r requirements.txt
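
A minimal sketch of how Haiku and Optax fit together (illustrative only, not this repository's training code):

import haiku as hk
import jax
import jax.numpy as jnp
import optax

# A tiny Haiku MLP optimized with one Optax Adam step.
def forward(x):
    return hk.nets.MLP([32, 1])(x)

model = hk.transform(forward)
key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
params = model.init(key, x)

opt = optax.adam(1e-3)
opt_state = opt.init(params)

def loss_fn(params, x):
    return jnp.mean(model.apply(params, None, x) ** 2)

grads = jax.grad(loss_fn)(params, x)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)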

Install CelebA dataset (optional)

To use the CelebA dataset, download it from Kaggle and put the images in the folder data/CelebA/images/img_align_celeba/. It is recommended to download the dataset from this source because the faces are already cropped.

Note: the other datasets will be installed automatically by Keras or tensorflow-datasets.

Quick Start

You can test a pretrained GAN model with apps/test.py. It will load one of the pretrained models (in pre_trained/) and generate pictures. You can change which GAN to test by changing the path in the script.

You can also train your own GAN from scratch with apps/train.py. To change the training parameters, edit the configs in the script. You can also change the dataset or the type of GAN by changing the imports (there is only one word to change for each).

Example to train a GAN on CelebA (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN

Then you can implement your own GANs and train/test them on your own datasets (by overriding the appropriate functions; check the examples in the repository).
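
For instance, a hypothetical subclass could look like the sketch below (the method names are illustrative, not the actual base-class interface; check the classes in gan/ before overriding):

from gan.dcgan import DCGAN

class MyGAN(DCGAN):
    # Hypothetical overrides: replace with the real methods of the base class.
    def get_generator(self):
        ...  # return your own Haiku generator

    def get_discriminator(self):
        ...  # return your own Haiku discriminator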

Some results of pre-trained models

- Deep Convolution GAN

  • On MNIST:

DCGAN MNIST

  • On Cifar10:

DCGAN Cifar10

  • On CelebA (64x64):

DCGAN CelebA-64

- Progressive Growing GAN

  • On MNIST:

  • On Cifar10:

  • On CelebA (64x64):

  • On CelebA (128x128):

Comments
  • FID implementation

    This implements the Fréchet Inception Distance (FID), inspired by and building on https://github.com/matthias-wright/jax-fid.

    To use it, run: python apps/test.py --eval FID --dataset_path your/path/to/dataset

    opened by AntoineAwaida 1
  • Reproducibility

    https://github.com/valentingol/GANJax/blob/f66a8707d2b2fa0298fa0257a91eeb1e4464b747/utils/data.py#L70

    Reproducing the bug

    1. I created a mini dataset in data/CelebA/miniSet containing 50 images from the full dataset.
    2. I modified load_images_celeba_xx in utils/data.py to automatically load the images of this mini dataset.
    3. I ran three experiments:
    • I ran app/train.py twice with batch_size = 5, num_epoch = 1, saving the generator and discriminator loss histories to different .npy files each time. After the two runs, I ended up with 4 files: gen_history.npy, disc_history.npy, gen_history2.npy and disc_history2.npy.
    • For the last run, I went back to the data loaders I had modified and set the shuffle buffer size below the dataset size (now 50):
    dataset = dataset.shuffle(buffer_size=5, seed=seed) # Changed buffer size from 1000 to 5
    

    The loss histories of this last experiment were stored in gen_history3.npy and disc_history3.npy.

    Observing the differences

    With this snippet, I loaded all the histories:

    import jax.numpy as jnp
    
    path = ''  # Path to where the arrays were stored
    
    gen_loss_history = jnp.load(f'{path}/gen_history.npy')
    disc_loss_history = jnp.load(f'{path}/disc_history.npy')
    
    gen_loss_history2 = jnp.load(f'{path}/gen_history2.npy')
    disc_loss_history2 = jnp.load(f'{path}/disc_history2.npy')
    
    gen_loss_history3 = jnp.load(f'{path}/gen_history3.npy')
    disc_loss_history3 = jnp.load(f'{path}/disc_history3.npy')
    
    print(gen_loss_history == gen_loss_history2)  # Returns True everywhere
    print(gen_loss_history == gen_loss_history3)  # Returns almost only False
    

    Conclusion

    So I think you were right @valentingol: the problem came from loading the images with tf and from that buffer size, which must be larger than the dataset size according to the official documentation.

    opened by Bassvelitchkine 1
  • Fixed reproducibility in tf data loading

    Fixes #11

    For the CelebA dataset, I just made sure that the buffer size of tf's data loader was larger than the dataset itself, to fix the reproducibility bug.
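
    A minimal sketch of the idea (assuming a tf.data pipeline like the one in utils/data.py; the dataset size below is illustrative):

    import tensorflow as tf

    num_images = 202599  # dataset size; the shuffle buffer must cover it
    dataset = tf.data.Dataset.range(num_images)
    # A buffer covering the whole dataset makes shuffling with a fixed seed deterministic.
    dataset = dataset.shuffle(buffer_size=num_images, seed=0)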

    opened by Bassvelitchkine 0
  • Initial Commit

    • add .gitignore, README, LICENSE
    • add setup.py
    • add utils for utility functions
    • add apps/train to train any type of GAN
    • add apps/test to test any type of GAN
    • DCGAN implementation:
      - add gan/dcgan main module for DCGAN
      - add pretrained CIFAR-10 model
    opened by valentingol 0
  • FID implementation

    This is an implementation of the Fréchet Inception Distance (FID).

    To use it:

    1. save the generated images, for instance using the apps/test.py script (specify the save_images_path argument to save your .pkl file of generated images)
    2. launch the evaluation script with the following command: python apps/evals.py --eval FID --dataset_path data/CelebA/img_align_celeba --generated_path ./generated/images.pkl, where dataset_path is the path to your directory of real images and generated_path the path to your generated images.
    opened by AntoineAwaida 0
  • Progressive Growing GAN implementation

    Implement the Progressive Growing GAN (https://arxiv.org/abs/1710.10196). Run it on MNIST, CelebA and CIFAR-10, add pretrained models and update the README.

    enhancement 
    opened by valentingol 0