# Hierarchical Few-Shot Generative Models

This repo contains code and experiments for the paper *Hierarchical Few-Shot Generative Models*.
## Settings

Clone the repo:

```bash
git clone https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models
cd hierarchical-few-shot-generative-models
```

Create and activate the conda env:

```bash
conda env create -f environment.yml
conda activate hfsgm
```

The code has been tested on Ubuntu 18.04, Python 3.6, and CUDA 11.3.

We use [wandb](https://wandb.ai) for visualization. The first time you run the code, you will need to log in.
## Data

We provide a preprocessed Omniglot dataset. From the main folder, copy the data into `data/omniglot_ns/`:

```bash
wget https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models/releases/download/Omniglot/omni_train_val_test.pkl
```

For CelebA you need to download the dataset from here.
## Dataset

In `dataset` we provide utilities to process and augment datasets in the few-shot setting. Each dataset is a large collection of small sets, and sets can be created dynamically. The `dataset/base.py` file collects basic info about the datasets. For binary datasets (`omniglot_ns.py`) we augment using flipping and rotations; for RGB datasets (`celeba.py`) we use only flipping.
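To make the flip/rotate augmentation concrete, here is a minimal sketch of set-level augmentation for binary images. This is a hypothetical helper, not the repo's actual code (`augment_binary_set` and its signature are illustrative); it applies one shared transform to every element so the set stays internally consistent.

```python
import random

import numpy as np

def augment_binary_set(x_set, rng=None):
    """Randomly flip and rotate a set of binary images.

    x_set: array of shape (k, H, W) with values in {0, 1}.
    The SAME transform is applied to every image in the set, so all
    elements keep looking like examples of one class.
    """
    rng = rng if rng is not None else random.Random(0)
    k = rng.choice([0, 1, 2, 3])       # number of 90-degree rotations
    flip = rng.random() < 0.5          # whether to mirror horizontally
    out = np.rot90(x_set, k=k, axes=(1, 2))
    if flip:
        out = out[:, :, ::-1]
    return out.copy()
```

Rotations and flips preserve pixel counts, so binary statistics of the set are unchanged; for RGB data only the flip branch would apply.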
## Experiment

In `experiment` we implement scripts for model evaluation, experiments and visualizations.

- `attention.py` - visualize attention weights and heads for models with learnable aggregation (LAG).
- `cardinality.py` - compute ELBOs for different input set sizes: [1, 2, 5, 10, 20].
- `classifier_mnist.py` - few-shot classifiers on MNIST.
- `kl_layer.py` - compute the KL over `z` and `c` for each layer in latent space.
- `marginal.py` - compute the approximate log-marginal likelihood with 1K importance samples.
- `refine_vis.py` - visualize refined samples.
- `sampling_rgb.py` - reconstruction, conditional, refined, and unconditional sampling for RGB datasets.
- `sampling_transfer.py` - reconstruction, conditional, refined, and unconditional sampling on transfer datasets.
- `sampling.py` - reconstruction, conditional, refined, and unconditional sampling for binary datasets.
- `transfer.py` - compute ELBOs on MNIST, DoubleMNIST, and TripleMNIST.
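As an illustration of the kind of quantity `marginal.py` reports, here is the standard importance-sampled estimate of the log-marginal likelihood. This is a generic sketch, not the repo's implementation; `log_marginal_is` and its arguments are illustrative names.

```python
import numpy as np

def log_marginal_is(log_joint, log_q, z_samples):
    """Importance-sampled estimate of the log-marginal likelihood:

        log p(x) ~= logsumexp_k[ log p(x, z_k) - log q(z_k | x) ] - log K

    with K samples z_k drawn from the proposal q(z | x).
    """
    log_w = log_joint(z_samples) - log_q(z_samples)  # (K,) log importance weights
    K = log_w.shape[0]
    m = log_w.max()                                   # stabilized logsumexp
    return m + np.log(np.exp(log_w - m).sum()) - np.log(K)
```

With 1K samples (K = 1000, as used in the paper's evaluation scripts) this gives a tighter bound than the ELBO, at the cost of K decoder passes per input.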
## Model

In `model` we implement baselines and model variants.

- `base.py` - base class for all the models.
- `vae.py` - Variational Autoencoder (VAE).
- `ns.py` - Neural Statistician (NS).
- `tns.py` - NS with learnable aggregation (NS-LAG).
- `cns.py` - NS with convolutional latent space (CNS).
- `ctns.py` - CNS with learnable aggregation (CNS-LAG).
- `hfsgm.py` - Hierarchical Few-Shot Generative Model (HFSGM).
- `thfsgm.py` - HFSGM with learnable aggregation (HFSGM-LAG).
- `chfsgm.py` - HFSGM with convolutional latent space (CHFSGM).
- `cthfsgm.py` - CHFSGM with learnable aggregation (CHFSGM-LAG).
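The "learnable aggregation" (LAG) variants replace the fixed mean pooling over set elements with a learned, attention-style pooling. A minimal NumPy sketch of the idea, with hypothetical names and a single head (the repo's implementation is a multi-head neural-network module and will differ):

```python
import numpy as np

def mean_pool(h):
    """Fixed aggregation: average the k element embeddings."""
    return h.mean(axis=0)

def attention_pool(h, query):
    """Learnable aggregation: a trainable query attends over set elements.

    h:     (k, d) element embeddings
    query: (d,)   learnable query vector
    Returns a (d,) set embedding; still permutation-invariant.
    """
    scores = h @ query / np.sqrt(h.shape[1])   # (k,) attention logits
    scores = scores - scores.max()             # numerical stability
    w = np.exp(scores) / np.exp(scores).sum()  # softmax weights over elements
    return w @ h                               # weighted average of the set
```

Both poolings are permutation-invariant, which is what makes them valid set aggregators; the attention weights let the model emphasize informative elements instead of treating all of them equally.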
## Script

In `script` we collect the scripts used for training the models in the paper. To run a CNS on Omniglot:

```bash
sh script/main_cns.sh GPU_NUMBER omniglot_ns
```
## Train a model

To train a generic model, run:

```bash
python main.py --name {VAE, NS, CNS, CTNS, CHFSGM, CTHFSGM} \
               --model {vae, ns, cns, ctns, chfsgm, cthfsgm} \
               --augment \
               --dataset omniglot_ns \
               --likelihood binary \
               --hidden-dim 128 \
               --c-dim 32 \
               --z-dim 32 \
               --output-dir /output \
               --alpha-step 0.98 \
               --alpha 2 \
               --adjust-lr \
               --scheduler plateau \
               --sample-size {2, 5, 10} \
               --sample-size-test {2, 5, 10} \
               --num-classes 1 \
               --learning-rate 1e-4 \
               --epochs 400 \
               --batch-size 100 \
               --tag (optional string)
```
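To make `--batch-size` and `--sample-size` concrete: in set-structured models each batch is a collection of sets, so inputs typically have shape `(batch_size, sample_size, channels, H, W)`. The layout below is an assumption about this codebase based on that common convention, not a description of its actual tensors:

```python
import numpy as np

batch_size, sample_size = 100, 5  # matches --batch-size 100 and --sample-size 5 above
x = np.zeros((batch_size, sample_size, 1, 28, 28))  # 100 sets of 5 binary images

# A shared per-element encoder usually sees a flat image batch, so sets are
# flattened before encoding and restored afterwards for set-level aggregation:
flat = x.reshape(batch_size * sample_size, 1, 28, 28)
restored = flat.reshape(batch_size, sample_size, 1, 28, 28)
```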
If you do not want to save logs, use the flag `--dry_run`. This flag will call `utils/trainer_dry.py` instead of `trainer.py`.
## Acknowledgments

A lot of code and ideas are borrowed from: