GAN Ensembling
Ensembling with Deep Generative Views.
Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang
CVPR 2021
Prerequisites
- Linux
- Python 3
- NVIDIA GPU + CUDA CuDNN
Table of Contents:
- Colab - run a limited demo version without local installation
- Setup - download required resources
- Quickstart - short demonstration code snippet
- Notebooks - Jupyter notebooks for visualization
- Pipeline - details on full pipeline
Colab
This Colab Notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.
Setup
- Clone this repo:
  git clone https://github.com/chail/gan-ensembling.git
  cd gan-ensembling
- Install dependencies:
  - We provide a Conda environment.yml file listing the dependencies. You can create the Conda environment using:
    conda env create -f environment.yml
- Download resources:
  - We provide a script for downloading associated resources. It downloads precomputed latent codes (cat: 291M, car: 121M, celebahq: 1.8G, cifar10: 883M), a subset of trained models (592M), precomputed results (1.3G), and associated libraries.
  - Fetch the resources by running:
    bash resources/download_resources.sh
  - Optional: to run the StyleGAN ID-invert models, the models must be downloaded separately. Follow the directions here to obtain styleganinv_ffhq256_encoder.pth and styleganinv_ffhq256_generator.pth, and place them in models/pretrain.
  - Optional: the download script fetches only the subset of pretrained models needed for the demo notebook. For further experiments, the additional pretrained models (7.0G total) can be downloaded here; they include 40 binary face attribute classifiers and classifiers trained on the different perturbation methods for the remaining datasets.
- Download external datasets:
  - CelebA-HQ: Follow the instructions here to create the CelebA-HQ dataset and place CelebA-HQ images in directory dataset/celebahq/images/images.
  - Cars: This dataset is a subset of Cars196. Download the images from here and the devkit from here. (We are subsetting their training images into train/test/val partitions.) Place the images in directory dataset/cars/images/images and the devkit in dataset/cars/devkit.
  - The processed and aligned cat images are downloaded with the above resources, and the CIFAR-10 dataset is downloaded via the PyTorch wrapper.
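The downloadable resources listed in the Setup steps add up to roughly 5.0G before the associated libraries, plus 7.0G if you also fetch the optional full set of pretrained models. The following is a minimal sketch for checking free disk space before running the download script. It assumes the quoted sizes are decimal megabytes/gigabytes; the size table and the helper names required_mb and has_enough_space are hypothetical and not part of the repository.

```python
import shutil

# Approximate sizes (MB) quoted in the Setup section; treating "M"/"G" as
# decimal units is an assumption. Associated libraries are not included.
RESOURCE_SIZES_MB = {
    "latents_cat": 291,
    "latents_car": 121,
    "latents_celebahq": 1800,
    "latents_cifar10": 883,
    "models_subset": 592,
    "precomputed_results": 1300,
}
OPTIONAL_MODELS_MB = 7000  # optional full set of pretrained models


def required_mb(include_optional: bool = False) -> int:
    """Return the total download size in MB."""
    total = sum(RESOURCE_SIZES_MB.values())
    if include_optional:
        total += OPTIONAL_MODELS_MB
    return total


def has_enough_space(path: str = ".", include_optional: bool = False) -> bool:
    """Check whether the filesystem containing `path` has room for the downloads."""
    free_mb = shutil.disk_usage(path).free / 1e6
    return free_mb >= required_mb(include_optional)


if __name__ == "__main__":
    print(f"required: {required_mb()} MB, enough space: {has_enough_space()}")
```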
An example of the directory organization is below:
dataset/celebahq/
    images/images/
        000004.png
        000009.png
        000014.png
        ...
    latents/
    latents_idinvert/
dataset/cars/
    devkit/
        cars_meta.mat
        cars_test_annos.mat
        cars_train_annos.mat
        ...
    images/images/
        00001.jpg
        00002.jpg
        00003.jpg
        ...
    latents/
dataset/catface/
    images/
    latents/
dataset/cifar10/
    cifar-10-batches-py/
    latents/
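Before running the pipeline, it can help to confirm that this layout is in place. Below is a minimal sketch using pathlib; the directory list is transcribed from the example tree above, and check_layout is a hypothetical helper, not part of the repository. Only the directories for the datasets you actually use need to exist.

```python
from pathlib import Path
from typing import List

# Directories transcribed from the example layout. latents_idinvert is only
# relevant if you run the optional StyleGAN ID-invert models.
EXPECTED_DIRS = [
    "dataset/celebahq/images/images",
    "dataset/celebahq/latents",
    "dataset/celebahq/latents_idinvert",
    "dataset/cars/devkit",
    "dataset/cars/images/images",
    "dataset/cars/latents",
    "dataset/catface/images",
    "dataset/catface/latents",
    "dataset/cifar10/cifar-10-batches-py",
    "dataset/cifar10/latents",
]


def check_layout(root: str = ".") -> List[str]:
    """Return the expected directories that are missing under `root`."""
    base = Path(root)
    return [d for d in EXPECTED_DIRS if not (base / d).is_dir()]


if __name__ == "__main__":
    missing = check_layout()
    if missing:
        print("Missing directories:")
        for d in missing:
            print("  " + d)
    else:
        print("All expected directories found.")
```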
Quickstart
Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to perturb GAN images. Additional examples are contained in notebooks/demo.ipynb.
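import data
from networks import domain_generator

# dataset, generator, and attribute (classification target) to use
dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'

# load the validation split together with its precomputed latent codes (load_w=True)
val_transform = data.get_transform(dataset_name, 'imval')
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

# each item is indexed as (image, latent, ...); [None] adds a batch dimension
index = 100
original_image = dset[index][0][None].cuda()
latent = dset[index][1][None].cuda()

# reconstruct the image from its precomputed latent code
gan_reconstruction = generator.decode(latent)

# sample 4 latent codes with a fixed seed and style-mix them into the fine layers
mix_latent = generator.seed2w(n=4, seed=0)
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)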
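Conceptually, style mixing keeps the image's own latent code in some generator layers and substitutes randomly sampled codes in others; in StyleGAN terminology, "fine" refers to the later, high-resolution layers. The NumPy sketch below illustrates that idea on per-layer (W+) codes. The layer count, the fine-layer range, and the function name stylemix_fine are assumptions for illustration; the repository's actual perturbation ranges are defined in networks/perturb_settings.py.

```python
import numpy as np


def stylemix_fine(latent, mix_latents, fine_layers):
    """Replace the per-layer codes in `fine_layers` with codes from `mix_latents`.

    latent:      (num_layers, dim) W+ code of the image being perturbed
    mix_latents: (n, dim) randomly sampled w codes, one per perturbed view
    fine_layers: slice of layer indices to overwrite (illustrative choice)
    Returns an array of shape (n, num_layers, dim), one mixed code per view.
    """
    n = mix_latents.shape[0]
    # start from n copies of the original per-layer code
    mixed = np.repeat(latent[None], n, axis=0)
    # broadcast each sampled code across the selected layers
    mixed[:, fine_layers, :] = mix_latents[:, None, :]
    return mixed


# hypothetical sizes: 14 layers, 512-dim codes, layers 8-13 treated as "fine"
rng = np.random.default_rng(0)
latent = rng.standard_normal((14, 512))
mix = rng.standard_normal((4, 512))
views = stylemix_fine(latent, mix, slice(8, 14))
print(views.shape)  # (4, 14, 512)
```

Each mixed code would then be decoded by the generator to produce one perturbed view; the coarse layers, which are left untouched here, carry the image's overall layout.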
Notebooks
Important: First, set up the symlinks required by the notebooks:
bash notebooks/setup_notebooks.sh
Then add the Conda environment to the Jupyter kernels:
python -m ipykernel install --user --name gan-ensembling
The provided notebooks are:
- notebooks/demo.ipynb: basic usage example
- notebooks/evaluate_ensemble.ipynb: plot classification test accuracy as a function of ensemble weight
- notebooks/plot_precomputed_evaluations.ipynb: generate the figures in the paper
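As a rough illustration of what an ensemble weight can control: the classifier's prediction on the original image is blended with its average prediction over the GAN-generated views, and test accuracy is measured at each weight. This is a minimal sketch, not the notebook's code; the convex-combination scheme, the function names ensemble_probs and accuracy_vs_weight, and the toy probabilities are all assumptions.

```python
import numpy as np


def ensemble_probs(p_orig, p_views, alpha):
    """Blend the prediction on the original image with the mean over GAN views.

    p_orig:  (N, C) class probabilities on the original images
    p_views: (N, V, C) class probabilities on V generated views per image
    alpha:   weight on the original image; 1.0 ignores the views entirely
    """
    return alpha * p_orig + (1.0 - alpha) * p_views.mean(axis=1)


def accuracy_vs_weight(p_orig, p_views, labels, weights):
    """Return test accuracy for each ensemble weight in `weights`."""
    return [
        float((ensemble_probs(p_orig, p_views, w).argmax(axis=1) == labels).mean())
        for w in weights
    ]


# toy example: 3 images, 2 classes, 2 generated views per image
p_orig = np.array([[0.6, 0.4], [0.4, 0.6], [0.45, 0.55]])
p_views = np.array([
    [[0.7, 0.3], [0.8, 0.2]],
    [[0.3, 0.7], [0.2, 0.8]],
    [[0.7, 0.3], [0.6, 0.4]],
])
labels = np.array([0, 1, 0])
print(accuracy_vs_weight(p_orig, p_views, labels, [0.0, 0.5, 1.0]))
```

With alpha = 1.0 this sketch reduces to the plain classifier on the original images, so sweeping alpha shows whether adding the generated views helps.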
Full Pipeline
The full pipeline contains three main parts:
- optimize latent codes
- train classifiers
- evaluate the ensemble of GAN-generated images.
Examples for each step of the pipeline are contained in the following scripts:
bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh
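The first step of the pipeline, optimizing latent codes, searches for a code whose generated image reconstructs a given real image. Below is a toy NumPy sketch of that idea in which the generator is replaced by a fixed random linear map and the loss is plain pixel MSE. Both simplifications, and the helper name invert_linear, are assumptions for illustration; the repository's scripts operate on the actual GAN generators with their own objectives.

```python
import numpy as np


def invert_linear(G, x, steps=1000, lr=0.1):
    """Find w minimizing mean((G @ w - x)**2) by gradient descent.

    G: (P, D) stand-in "generator" mapping a D-dim code to P pixels
    x: (P,) target image, flattened
    """
    w = np.zeros(G.shape[1])
    for _ in range(steps):
        residual = G @ w - x
        grad = 2.0 * G.T @ residual / len(x)  # gradient of the mean squared error
        w -= lr * grad
    return w


rng = np.random.default_rng(0)
G = rng.standard_normal((64, 8))  # 64 "pixels", 8-dim latent (hypothetical sizes)
w_true = rng.standard_normal(8)
x = G @ w_true                    # an image the toy generator can represent exactly
w_hat = invert_linear(G, x)
print(np.abs(w_hat - w_true).max())
```

Because the toy target lies exactly in the generator's range, the recovered code matches the true one; with a real GAN, the reconstruction is only approximate, which is why the precomputed latent codes are provided.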
To add to the pipeline:
- Data: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions. See data/data_*.py for examples.
- Generator: modify networks/domain_generators.py to add the generator in domain_generators.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
- Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier.
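For the Data step, the Quickstart indexes each item as (image, latent, ...) when load_w=True. A minimal sketch of a dataset class that follows the PyTorch Dataset protocol (__len__ and __getitem__) is shown below, written without torch so that it stays self-contained. The class name LatentImageDataset, its constructor arguments, and the (image, latent, label) return order are assumptions; the real classes in data/data_*.py may differ.

```python
import numpy as np


class LatentImageDataset:
    """Hypothetical dataset pairing each image with a precomputed latent code.

    Implements the same __len__/__getitem__ protocol as torch.utils.data.Dataset,
    so in a real project it could subclass that class.
    """

    def __init__(self, images, latents, labels, transform=None, load_w=True):
        assert len(images) == len(latents) == len(labels)
        self.images = images
        self.latents = latents
        self.labels = labels
        self.transform = transform
        self.load_w = load_w

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.load_w:
            # image first, latent second, matching dset[index][0] / dset[index][1]
            return image, self.latents[index], self.labels[index]
        return image, self.labels[index]


# toy data: 5 RGB images of size 8x8, 512-dim latents, binary labels
rng = np.random.default_rng(0)
dset = LatentImageDataset(
    images=rng.random((5, 3, 8, 8)),
    latents=rng.standard_normal((5, 512)),
    labels=np.array([0, 1, 1, 0, 1]),
    transform=lambda im: im * 2.0 - 1.0,  # rescale [0, 1) to [-1, 1)
)
image, latent, label = dset[2]
print(len(dset), image.shape, latent.shape, label)
```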
Acknowledgements
We thank the authors of these repositories:
- GAN Seeing for GAN and visualization utilities
- StyleGAN 2 Pytorch for the PyTorch implementation of StyleGAN 2 and pretrained models (license)
- Stylegan 2 ADA Pytorch for the class-conditional StyleGAN 2 CIFAR10 generator (license)
- StyleGAN In-domain inversion for the in-domain StyleGAN generator and encoder (license)
- Pytorch CIFAR for CIFAR10 classification (license)
- Latent Composition for some code and remaining encoders (license)
- Cat dataset images are from the Oxford-IIIT Pet Dataset (license), aligned using the Frederic landmark detector (license).
Citation
If you use this code for your research, please cite our paper:
@inproceedings{chai2021ensembling,
title={Ensembling with Deep Generative Views.},
author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
booktitle={CVPR},
year={2021}
}