[ICLR'21] Counterfactual Generative Networks

Overview


[Project] [PDF] [Blog] [Music Video] [Colab]

This repository contains the code for the ICLR 2021 paper "Counterfactual Generative Networks" by Axel Sauer and Andreas Geiger. If you want to take the CGN for a spin and generate counterfactual images, you can try out the Colab notebook linked above.


If you find our code or paper useful, please cite:

@inproceedings{Sauer2021ICLR,
  author = {Axel Sauer and Andreas Geiger},
  title = {Counterfactual Generative Networks},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year = {2021}
}

Setup

Install Anaconda (if you don't have it yet):

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile

Clone the repo and build the environment:

git clone https://github.com/autonomousvision/counterfactual_generative_networks
cd counterfactual_generative_networks
conda env create -f environment.yml
conda activate cgn

Make all scripts executable with chmod +x scripts/*. Then, download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net). In the download scripts, comment out the downloads you don't need.

./scripts/download_data.sh
./scripts/download_weights.sh

MNISTs

The main functions of this sub-repo are:

  • Generating the MNIST variants
  • Training a CGN
  • Generating counterfactual datasets
  • Training a shape classifier

Train the CGN

We provide well-working configs and weights in mnists/experiments. To train a CGN on, e.g., Wildlife MNIST, run

python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

For more info, add --help. Weights and samples will be saved in mnists/experiments/.

Generate Counterfactual Data

To generate the counterfactuals for, e.g., double-colored MNIST, run

python mnists/generate_data.py \
--weight_path mnists/experiments/cgn_double_colored_MNIST/weights/ckp.pth \
--dataset double_colored_MNIST --no_cfs 10 --dataset_size 100000

Make sure that the --dataset argument matches the dataset the weights were trained on. To use your own weights, change --weight_path. The command above generates ten counterfactuals per shape.
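To make "ten counterfactuals per shape" concrete, here is a hypothetical sketch of the label sampling, assuming a 10-class CGN whose shape, texture, and background mechanisms each take their own class label: every shape label is kept fixed while the texture and background labels are redrawn independently no_cfs times. The function name and the uniform sampling are assumptions for illustration, not the logic of mnists/generate_data.py.

```python
import random


def sample_counterfactual_labels(n_shapes, no_cfs=10, n_classes=10, seed=0):
    """Hypothetical sketch: return (shape, texture, background) label triples.

    Each of the n_shapes shape labels is repeated no_cfs times, and every
    repetition gets freshly drawn texture and background labels, so the
    result has n_shapes * no_cfs entries.
    """
    rng = random.Random(seed)  # local RNG so the output is reproducible
    triples = []
    for _ in range(n_shapes):
        shape = rng.randrange(n_classes)
        for _ in range(no_cfs):
            texture = rng.randrange(n_classes)
            background = rng.randrange(n_classes)
            triples.append((shape, texture, background))
    return triples


labels = sample_counterfactual_labels(n_shapes=3, no_cfs=10)
print(len(labels))  # 30
```

Because texture and background are drawn independently of the shape label in this sketch, a classifier trained on the shape label cannot rely on color or background cues, which is the invariance the shape classifier below is meant to learn.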

Train the Invariant Classifier

The classifier training uses tensor datasets, so you first need to save the non-counterfactual datasets as tensors. For DATASET = {colored_MNIST, double_colored_MNIST, wildlife_MNIST}, run

python mnists/generate_data.py --dataset DATASET

To train, e.g., a shape classifier (invariant to foreground and background) on Wildlife MNIST, run

python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

Add --help for info on the available options and arguments. We use the same hyperparameters for all experiments.

ImageNet

The main functions of this sub-repo are:

  • Training a CGN
  • Generating data (samples, interpolations, or a whole dataset)
  • Training an invariant classifier ensemble

Train the CGN

Run

python imagenet/train_cgn.py --model_name MODEL_NAME

The default parameters should give you satisfactory results. You can change them in imagenet/config.yml. For more info, add --help. Weights and samples will be saved in imagenet/data/MODEL_NAME.

Generate Counterfactual Data

Samples. To generate a dataset of counterfactual images, run

python imagenet/generate_data.py --mode random --weights_path imagenet/weights/cgn.pth \
--n_data 100 --run_name RUN_NAME \
--truncation 0.5 --batch_sz 1

The results will be saved in imagenet/data. For more info, add --help. If you want to save only masks, textures, etc., you need to change this directly in the code (see line 206).

The labels will be stored in a csv file. You can read them as follows:

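import pandas as pd

path = "path/to/labels.csv"  # placeholder: point this at the labels CSV written by generate_data.py
df = pd.read_csv(path, index_col=0)
df = df.set_index('im_name')  # look up rows by image file name
shape_cls = df.loc['RUN_NAME_0000000.png', 'shape_cls']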

Generating a dataset to train a classifier. Produce one dataset with --run_name train and another with --run_name val. If you have several GPUs available, you can index the name, e.g., --run_name train_GPU_NUM. The class ImagenetCounterfactual will glob all these datasets and combine them into a single, large training set. Make sure to set --batch_sz 1: with a larger batch size, each batch is saved as a single PNG, which is useful for visualization but not for training.
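The globbing step can be pictured with a small standard-library sketch. It assumes (hypothetically) that each run directory under imagenet/data is named after its --run_name (train, train_0, train_1, ...), holds its PNGs next to one CSV with the im_name and shape_cls columns shown above, and that merging simply concatenates the rows. The function collect_counterfactual_runs is illustrative only; it is not the ImagenetCounterfactual class.

```python
import csv
import glob
import os


def collect_counterfactual_runs(root, split="train"):
    """Hypothetical sketch: merge all runs named `split`, `split_0`, `split_1`, ...

    Assumes every run directory under `root` contains its PNG images and one
    labels CSV with `im_name` and `shape_cls` columns (an unnamed index column
    may precede them). Returns a single list of (image_path, shape_cls) pairs.
    """
    samples = []
    # sorted() keeps the merged order deterministic across file systems
    for run_dir in sorted(glob.glob(os.path.join(root, f"{split}*"))):
        if not os.path.isdir(run_dir):
            continue
        for csv_path in sorted(glob.glob(os.path.join(run_dir, "*.csv"))):
            with open(csv_path, newline="") as f:
                for row in csv.DictReader(f):
                    image_path = os.path.join(run_dir, row["im_name"])
                    samples.append((image_path, int(row["shape_cls"])))
    return samples
```

Under these assumptions, runs produced on two GPUs with --run_name train_0 and --run_name train_1 end up in one training list, while --run_name val stays separate.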

Interpolations. To generate interpolation sheets, e.g., from a barn (425) to a whale (147), run

python imagenet/generate_data.py --mode fixed_classes \
--n_data 1 --weights_path imagenet/weights/cgn.pth --run_name barn_to_whale \
--truncation 0.3 --interp all --classes 425 425 425 --interp_cls 147 --save_noise

You can also do counterfactual interpolations, i.e., interpolating only over, e.g., shape, by setting --interp shape.
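Conceptually, --interp all moves every component toward the target class, while --interp shape moves only the shape. The sketch below illustrates this with linearly interpolated one-hot class vectors, assuming a generator with separate shape, texture, and background class inputs (as the teaser description below suggests) and assuming the three --classes values are given in the order (shape, texture, background). The function interpolation_schedule and its arguments are hypothetical, not the code path of imagenet/generate_data.py.

```python
import numpy as np

MECHANISMS = ("shape", "texture", "background")


def one_hot(cls, n_classes=1000):
    """One-hot class vector (ImageNet has 1000 classes)."""
    v = np.zeros(n_classes, dtype=np.float64)
    v[cls] = 1.0
    return v


def interpolation_schedule(classes, interp_cls, n_steps=5, interp="all", n_classes=1000):
    """Hypothetical sketch: per-step class vectors for each mechanism.

    classes: start classes as (shape, texture, background).
    interp_cls: target class.
    interp: "all" interpolates every mechanism; a mechanism name
    interpolates only that one and keeps the others fixed.
    """
    if interp != "all" and interp not in MECHANISMS:
        raise ValueError(f"unknown interp mode: {interp}")
    moving = MECHANISMS if interp == "all" else (interp,)
    start = {m: one_hot(c, n_classes) for m, c in zip(MECHANISMS, classes)}
    target = one_hot(interp_cls, n_classes)
    schedule = []
    for alpha in np.linspace(0.0, 1.0, n_steps):
        step = {}
        for m in MECHANISMS:
            # Linear blend for moving mechanisms; fixed start vector otherwise
            step[m] = (1.0 - alpha) * start[m] + alpha * target if m in moving else start[m]
        schedule.append(step)
    return schedule


# Barn (425) to whale (147), interpolating only the shape
steps = interpolation_schedule((425, 425, 425), 147, n_steps=3, interp="shape")
print(steps[1]["shape"][425], steps[1]["shape"][147])  # 0.5 0.5
```

Linearly interpolating one-hot vectors is equivalent to interpolating the output of a linear class-embedding layer, which is one common way to blend classes in class-conditional GANs.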

Interpolation GIF. To generate a GIF like the one in the teaser (sample an image of class $1, then interpolate to shape $2, then to background $3, then to shape $4, and finally back to $1), run

./scripts/generate_teaser_gif.sh 992 293 147 330

The positional arguments are ImageNet class indices; see the ImageNet labels for the available options.

Train the Invariant Classifier Ensemble

Training. First, make sure that you have all datasets in imagenet/data/. Download ImageNet (e.g., from Kaggle), produce a counterfactual dataset (see above), and download the Cue-Conflict and BG-Challenge datasets (via the download script in scripts).

To train a classifier on a single GPU with a pre-trained ResNet-50 backbone, run

python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME

Again, add --help for more information on the possible arguments.

Distributed Training. To switch to multi-GPU training, run echo $CUDA_VISIBLE_DEVICES to see if the GPUs are visible. In the case of a single node with several GPUs, you can run, e.g.,

python imagenet/train_classifier.py -a resnet50 -b 256 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME \
--rank 0 --multiprocessing-distributed --dist-url tcp://127.0.0.1:8890 --world-size 1

If your setup differs, e.g., you train on several GPU machines, you need to adapt the rank and world size.
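The flags in these commands (-a, -b, -j, --dist-url, --multiprocessing-distributed, --world-size, --rank) appear to follow the PyTorch ImageNet reference example. Assuming imagenet/train_classifier.py keeps that example's conventions (not confirmed here), --world-size and --rank count nodes, -b is the total batch size per node, and each spawned per-GPU process derives its own values as sketched below. Under that assumption, two machines with eight GPUs each would use --world-size 2, with --rank 0 on the first machine and --rank 1 on the second. The helper process_setup is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ProcessSetup:
    global_rank: int
    world_size: int
    batch_size: int
    workers: int


def process_setup(node_rank, n_nodes, gpus_per_node, local_gpu, batch_size, workers):
    """Hypothetical sketch of the per-process arithmetic used by the PyTorch
    ImageNet example when --multiprocessing-distributed is set."""
    world_size = n_nodes * gpus_per_node                # one process per GPU on all nodes
    global_rank = node_rank * gpus_per_node + local_gpu
    per_gpu_batch = batch_size // gpus_per_node         # -b is split across the node's GPUs
    per_gpu_workers = (workers + gpus_per_node - 1) // gpus_per_node  # ceiling division
    return ProcessSetup(global_rank, world_size, per_gpu_batch, per_gpu_workers)


print(process_setup(node_rank=0, n_nodes=1, gpus_per_node=8, local_gpu=3, batch_size=256, workers=6))
# ProcessSetup(global_rank=3, world_size=8, batch_size=32, workers=1)
```

With this convention, -b 256 on a single node with eight GPUs gives each GPU a batch of 32, the same per-GPU batch size as the single-GPU command above.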

Visualization. To visualize the Tensorboard outputs, run tensorboard --logdir=imagenet/runs and open the local address in your browser.

Acknowledgments

We would like to acknowledge several repos whose code, data, or models we use in parts of our implementation:

Comments
  • Bad key "text.kerning_factor" - cgn_colored_MNIST

    Thanks for the great code and paper.

    python mnists/train_cgn.py --cfg mnists/experiments/cgn_colored_MNIST/cfg.yaml
    
    Bad key "text.kerning_factor" on line 4 in
    /home/local/AD/cordun1/anaconda3/envs/gans/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
    You probably need to get an updated matplotlibrc file from
    https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
    or from the matplotlib source distribution
    
    initialize network with orthogonal
    Traceback (most recent call last):
      File "mnists/train_cgn.py", line 179, in <module>
        main(cfg)
      File "mnists/train_cgn.py", line 131, in main
        cfg.TRAIN.WORKERS)
      File "/home/local/AD/cordun1/experiments/causality/counterfactual_generative_networks/mnists/dataloader.py", line 182, in get_dataloaders
        ds_train = MNIST(train=True)
      File "/home/local/AD/cordun1/experiments/causality/counterfactual_generative_networks/mnists/dataloader.py", line 17, in __init__
        data_dic = np.load(self.data_path, encoding='latin1', allow_pickle=True).item()
      File "/home/local/AD/cordun1/anaconda3/envs/gans/lib/python3.7/site-packages/numpy/lib/npyio.py", line 416, in load
        fid = stack.enter_context(open(os_fspath(file), "rb"))
    FileNotFoundError: [Errno 2] No such file or directory: 'mnists/data/colored_mnist/mnist_10color_jitter_var_0.020.npy'
    
    opened by nudro
  • Implementation of DoubleColoredMNIST

    Hi!

    Thank you so much for posting your code online! I was really fascinated by your work and I presented it recently at the computer vision seminar at Stony Brook University.

    I was tinkering with the implementation of DoubleColoredMNIST because I want to use it in a related project and I noticed something strange: In the class implementation, colors are assigned to global tensors and every time an image is requested, the global color tensors get jittered instead of jittering a local copy. Is this the intended implementation or is this something unintended?

    Reproduction code:

    ds = DoubleColoredMNIST(train=True)
    
    print(ds.background_colors[5].view(1, 3), ds.object_colors[5].view(1, 3))
    # -> tensor([[0.8706, 0.7216, 0.5294]]) tensor([[0.3922, 0.5843, 0.9294]])
    
    first_image = ds[0]
    image, label = first_image['ims'], first_image['labels']
    
    print(ds.background_colors[5].view(1, 3), ds.object_colors[5].view(1, 3))
    # -> tensor([[0.8725, 0.7374, 0.5306]]) tensor([[0.3945, 0.5922, 0.9359]])
    

    As you can see, the tensors change between the two calls.

    This issue comes from line 87 and line 91 in mnists/dataloader.py, where I assume the intended action is to copy the color tensor. Instead, the local variables back_color and obj_color refer to the global color tensors, which then get jittered.

    # line 87:
    back_color = self.background_colors[i]
    ...
    
    # line 91:
    obj_color = self.object_colors[i]
    

    Adding .clone() to the end of each assignment avoids this problem.

    # line 87:
    back_color = self.background_colors[i].clone()
    ...

    # line 91:
    obj_color = self.object_colors[i].clone()
    

    Let me know if this is actually a bug or feature of the implementation 😅

    opened by gessha
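The behavior described in this issue follows from standard PyTorch semantics: indexing a row of a tensor returns a view that shares storage with the original, so an in-place update on that view also changes the original, while .clone() makes an independent copy. Below is a minimal sketch, with a random 10-color palette standing in for self.background_colors and an assumed in-place Gaussian jitter; the exact jitter used in mnists/dataloader.py is not shown here, and the function names are hypothetical.

```python
import torch

torch.manual_seed(0)
# Stand-in for self.background_colors: one RGB color per digit class
palette = torch.rand(10, 3)


def jitter_aliased(colors, i, std=0.01):
    color = colors[i]                # a view that shares storage with `colors`
    color += torch.randn(3) * std    # in-place add also modifies colors[i]
    return color


def jitter_copied(colors, i, std=0.01):
    color = colors[i].clone()        # independent copy
    color += torch.randn(3) * std    # `colors` is left untouched
    return color


before = palette.clone()
jitter_copied(palette, 5)
print(torch.equal(palette, before))  # True
jitter_aliased(palette, 5)
print(torch.equal(palette, before))  # False
```

An out-of-place update such as color = color + noise would also leave the palette untouched, since it allocates a new tensor instead of writing into the view.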
  • Consulting

    May I ask whether counterfactual data can be generated for CIFAR-10? Your code takes the ImageNet dataset as input, and there are some parts of the data-generation code that I don't know how to change. Could you give me some advice?

    opened by Tancong2021
  • Broken Google Drive link

    Hi,

    The link in line 3 of scripts/download_weights.sh is broken. The URL https://drive.google.com/u/0/uc?id=1VkKexkWh5SeB8fgxAZxLKgmmvDXhVYUy&export=downloadl does not work, even after correcting downloadl to download.

    opened by thomas-w-nl
  • wild_MNIST/cfg.yaml hyperparameters

    Hi, thank you for the nice work! I was trying to train a CGN for the Wildlife MNIST dataset using the provided hyperparameters (cgn_wildlife_MNIST/cfg.yaml). However, when training the invariant classifier, the accuracy (~76%, with many more epochs) is not comparable to what I get with your provided model (cgn_wildlife_MNIST/weights/ckp.pth). I am wondering whether cfg.yaml needs further fine-tuning, or whether there is anything I should pay attention to. Thank you very much!

    opened by qichenglao