PyTorch implementation of the TTC algorithm

Overview

Trust-the-Critics

This repository is a PyTorch implementation of the TTC algorithm and the WGAN misalignment experiments presented in Trust the Critics: Generatorless and Multipurpose WGANs with Initial Convergence Guarantees.

How to run this code

  • Create a Python virtual environment with Python 3.8 installed.
  • Install the necessary Python packages listed in the requirements.txt file (this can be done through pip install -r /path/to/requirements.txt).

In the example_shell_scripts folder, we include samples of shell scripts we used to run our experiments. We note that training generative models is computationally demanding, and thus requires adequate computational resources (i.e. running this on your laptop is not recommended).

TTC algorithm

The various experiments we run with TTC are described in Section 5 and Addendix B of the paper. Illustrating the flexibility of the TTC algorithm, the image generation, denoising and translation experiments can all be run using the ttc.py script; the only necessary changes are the source and target datasets. Running TTC with a given source and a given target will train and save several critic neural networks that can subsequently be used to push the source distribution towards the target distribution by applying the 'steptaker' function found in TTC_utils/steptaker.py once for each critic.

Necessary arguments for ttc.py are:

  • 'source' : The name of the distribution or dataset that is to be pushed towards the target (options are listed in ttc.py).
  • 'target' : The name of the target dataset (options are listed in ttc.py).
  • 'data' : The path of a directory where the necessary data is located. This includes the target dataset, in a format that can be accessed by a dataloader object obtained from the corresponding function in dataloader.py. Such a dataloader always belongs to the torch.utils.data.DataLoader class (e.g. if target=='mnist', then the corresponding dataloader will be an instance of torchvision.datasets.MNIST, and the MNIST dataset should be placed in 'data' in a way that reflects this). If the source is a dataset, it needs to be placed in 'data' as well. If source=='untrained_gen', then the untrained generator used to create the source distribution needs to be saved under 'data/ugen.pth'.
  • 'temp_dir' : The path of a directory where the trained critics will be saved, along with a few other files (including the log.pkl file that contains the step sizes). Despite the name, this folder isn't necessarily temporary.

Other optional arguments are described in a commented section at the top of the ttc.py script. Note that running ttc.py will only train the critics that the TTC algorithm uses to push the source distribution towards the target distribution, it will not actually push any samples from the source towards the target (as mentioned above, this is done using the steptaker function).

TTC image generation
For a generative experiment, run ttc.py with the source argument set to either 'noise' or 'untrained_gen' and the target of your choice. Then, run ttc_eval.py, which will use the saved critics and step sizes to push noise inputs towards the target distribution according to the TTC algorithm (using the steptaker function), and which will optionally evaluate generative performance with FID and/or MMD (FID is used in the paper). The arguments 'source', 'target', 'data', 'temp_dir' and 'model' for ttc_eval.py should be set to the same values as when running ttc.py. If evaluating FID, the folder specified by 'temp_dir' should contain a subdirectory named 'temp_dir/{target}test' (e.g. 'temp_dir/mnisttest' if target=='mnist') containing the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).

TTC denoising
For a denoising experiment, run ttc.py with source=='noisybsds500' and target=='bsds500' (specifying a noise level with the 'sigma' argument). Then, run denoise_eval.py (with the same 'temp_dir', 'data' and 'model' arguments), which will add noise to images, denoise them using the TTC algorithm and the saved critics, and evaluate PSNR's.

TTC Monet translation
For a denoising experiment, run ttc.py with source=='photo' and target=='monet'. Then run ttc_eval.py (with the same 'source', 'target', 'temp_dir', 'data' and 'model' arguments, and presumably with no FID or MMD evaluation), which will sample realistic images from the source and make them look like Monet paintings.

WGAN misalignment

The WGAN misalignment experiments are described in Section 3 and Appendix B.1 of the paper, and are run using misalignments.py. This script trains a WGAN while, at some iterations, measuring how misaligned the movement of generated samples caused by updating the generator is from the critic's gradient. The generator's FID is also measured at the same iterations.

The required arguments for misalignments.py are:

  • 'target' : The dataset used to train the WGAN - can be either 'mnist' or 'fashion' (for Fashion-MNIST).
  • 'data' : The path of a folder where the MNIST (or Fashion-MNIST) dataset is located, in a format that can be accessed by an instance of the torchvision.datasets.MNIST class (resp torchvision.datasets.FashionMNIST).
  • 'fid_data' : The path of a folder containing the test data from the MNIST dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).
  • 'checkpoints' : A string of integers separated by underscores. The integers specify the iterations at which misalignments and FID are computed, and training will continue until the largest iteration is reached.

Other optional arguments (including 'results_path' and 'temp_dir') are described in a commented section at the top of the misalignments.py. The misalignment results reported in the paper (Tables 1 and 5, and Figure 3), correspond to using the default hyperparameters and to setting the 'checkpoints' argument roughly equal to '10_25000_40000', with '10' corresponding the early stage in training, '25000' to the mid stage, and '40000' to the late stage.

WGAN generation

For completeness we include the code that was used to obtain the WGAN FID statistics in Table 3 of the paper, which includes the wgan_gp.py and wgan_gp_eval.py scripts. The former trains a WGAN with the InfoGAN architecture on the dataset specified by the 'target' argument, saving generator model dictionaries in the folder specified by 'temp_dir' at ten equally spaced stages in training. The wgan_gp_eval.py script evaluates the performance of the generator with the different model dictionaries in 'temp_dir'.

The necessary arguments to run wgan_gp.py are:

  • 'target' : The name of the dataset to generate (can be either 'mnist', 'fashion' or 'cifar10').
  • 'data' : Folder where the dataset is located.
  • 'temp_dir' : Folder where the model dictionaries are saved.

Once wgan_gp.py has run, wgan_gp_eval.py should be called with the same arguments for 'target', 'data' and 'temp_dir', and setting the 'model' argument to 'infogan'. If evaluating FID, the 'temp_dir' folder needs to contain the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).

Reproducibility

This repository contains two branches: 'main' and 'reproducible'. You are currectly viewing the 'main' branch, which contains a clean version of the code meant to be easy to read and interpret and to run more efficiently than the version on the 'reproducible' branch. The results obtained by running the code on this branch should be nearly (but not perfectly) identical to the results stated in the papers, the differences stemming from the randomness inherent to the experiments. The 'reproducible' branch allows one to replicate exactly the results stated in the paper (random seeds are specified) for the TTC experiments.

Computing architecture and running times

We ran different versions of the code presented here on Compute Canada (https://www.computecanada.ca/) clusters, always using a single NVIDIA V100 Volta or NVIDIA A100 Ampere GPU. Here are rough estimations of the running times for our experiments.

  • MNIST/Fashion MNIST generation training run (TTC): 60-90 minutes.
  • MNIST/Fashion MNIST generation training run (WGAN): 45-90 minutes (this includes misalignments computations).
  • CIFAR10 generation training run: 3-4 hours (TTC), 90 minutes (WGAN-GP).
  • Image translation training run: up to 20 hours.
  • Image denoising training run: 8-10 hours.

Assets

Portions of this code, as well as the datasets used to produce our experimental results, make use of existing assets. We provide here a list of all assets used, along with the licenses under which they are distributed, if specified by the originator:

You might also like...
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch
ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch

PPO Pytorch C++ This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch. It uses a simple TestEnvironment t

PyTorch implementation of DreamerV2 model-based RL algorithm

PyDreamer Reimplementation of DreamerV2 model-based RL algorithm in PyTorch. The official DreamerV2 implementation can be found here. Features ... Run

PyTorch implementation of the implicit Q-learning algorithm (IQL)
PyTorch implementation of the implicit Q-learning algorithm (IQL)

Implicit-Q-Learning (IQL) PyTorch implementation of the implicit Q-learning algorithm IQL (Paper) Currently only implemented for online learning. Offl

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning"

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning".

A pytorch reprelication of the model-based reinforcement learning algorithm MBPO
A pytorch reprelication of the model-based reinforcement learning algorithm MBPO

Overview This is a re-implementation of the model-based RL algorithm MBPO in pytorch as described in the following paper: When to Trust Your Model: Mo

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.
An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

Implements pytorch code for the Accelerated SGD algorithm.

AccSGD This is the code associated with Accelerated SGD algorithm used in the paper On the insufficiency of existing momentum schemes for Stochastic O

PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).
PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).

PyGAD: Genetic Algorithm in Python PyGAD is an open-source easy-to-use Python 3 library for building the genetic algorithm and optimizing machine lear

Owner
null
PyTorch implementation of neural style transfer algorithm

neural-style-pt This is a PyTorch implementation of the paper A Neural Algorithm of Artistic Style by Leon A. Gatys, Alexander S. Ecker, and Matthias

null 770 Jan 2, 2023
PyTorch implementation of DeepDream algorithm

neural-dream This is a PyTorch implementation of DeepDream. The code is based on neural-style-pt. Here we DeepDream a photograph of the Golden Gate Br

null 121 Nov 5, 2022
PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models This repository will reproduce the main results from our pape

Mitch Hill 32 Nov 25, 2022
A Pytorch implementation of the multi agent deep deterministic policy gradients (MADDPG) algorithm

Multi-Agent-Deep-Deterministic-Policy-Gradients A Pytorch implementation of the multi agent deep deterministic policy gradients(MADDPG) algorithm This

Phil Tabor 159 Dec 28, 2022
An unofficial PyTorch implementation of a federated learning algorithm, FedAvg.

Federated Averaging (FedAvg) in PyTorch An unofficial implementation of FederatedAveraging (or FedAvg) algorithm proposed in the paper Communication-E

Seok-Ju Hahn 123 Jan 6, 2023
Official PyTorch implementation for FastDPM, a fast sampling algorithm for diffusion probabilistic models

Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models". FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. S

Zhifeng Kong 68 Dec 26, 2022
Author's PyTorch implementation of Randomized Ensembled Double Q-Learning (REDQ) algorithm.

REDQ source code Author's PyTorch implementation of Randomized Ensembled Double Q-Learning (REDQ) algorithm. Paper link: https://arxiv.org/abs/2101.05

null 109 Dec 16, 2022
PyTorch implementation of our Adam-NSCL algorithm from our CVPR2021 (oral) paper "Training Networks in Null Space for Continual Learning"

Adam-NSCL This is a PyTorch implementation of Adam-NSCL algorithm for continual learning from our CVPR2021 (oral) paper: Title: Training Networks in N

Shipeng Wang 34 Dec 21, 2022
A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019).

ClusterGCN ⠀⠀ A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019). A

Benedek Rozemberczki 697 Dec 27, 2022
Pytorch implementation of the DeepDream computer vision algorithm

deep-dream-in-pytorch Pytorch (https://github.com/pytorch/pytorch) implementation of the deep dream (https://en.wikipedia.org/wiki/DeepDream) computer

null 102 Dec 5, 2022