pythae
This library implements some of the most common (Variational) Autoencoder models. In particular, it lets you run benchmark experiments and comparisons by training the models with the same autoencoding neural network architecture. The "make your own autoencoder" feature lets you train any of these models with your own data and your own Encoder and Decoder neural networks.
Installation
To install the latest version of this library run the following using pip
$ pip install git+https://github.com/clementchadebec/benchmark_VAE.git
or alternatively you can clone the GitHub repo to access the tests, tutorials and scripts.
$ git clone https://github.com/clementchadebec/benchmark_VAE.git
and install the library
$ cd benchmark_VAE
$ pip install -e .
Available Models
Below is the list of the models currently implemented in the library.
Models | Training example | Paper | Official Implementation |
---|---|---|---|
Autoencoder (AE) | |||
Variational Autoencoder (VAE) | link | ||
Beta Variational Autoencoder (Beta_VAE) | link | ||
Importance Weighted Autoencoder (IWAE) | link | link | |
Wasserstein Autoencoder (WAE) | link | link | |
Info Variational Autoencoder (INFOVAE_MMD) | link | ||
VAMP Autoencoder (VAMP) | link | link | |
Hamiltonian VAE (HVAE) | link | link | |
Regularized AE with L2 decoder param (RAE_L2) | link | link | |
Regularized AE with gradient penalty (RAE_GP) | link | link | |
Riemannian Hamiltonian VAE (RHVAE) | link |
See results for all aforementioned models
Available Samplers
Below is the list of the samplers currently implemented in the library.
Samplers | Models | Paper | Official Implementation |
---|---|---|---|
Normal prior (NormalSampler) | all models | link | |
Gaussian mixture (GaussianMixtureSampler) | all models | link | link |
VAMP prior sampler (VAMPSampler) | VAMP | link | link |
Manifold sampler (RHVAESampler) | RHVAE | link | |
Two stage VAE sampler (TwoStageVAESampler) | all VAE based models | link | link |
Launching a model training
To launch a model training, you only need to call a TrainingPipeline
instance.
>>> from pythae.pipelines import TrainingPipeline
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.trainers import BaseTrainingConfig
>>> # Set up the training configuration
>>> my_training_config = BaseTrainingConfig(
... output_dir='my_model',
... num_epochs=50,
... learning_rate=1e-3,
... batch_size=200,
... steps_saving=None
... )
>>> # Set up the model configuration
>>> my_vae_config = VAEConfig(
... input_dim=(1, 28, 28),
... latent_dim=10
... )
>>> # Build the model
>>> my_vae_model = VAE(
... model_config=my_vae_config
... )
>>> # Build the Pipeline
>>> pipeline = TrainingPipeline(
... training_config=my_training_config,
... model=my_vae_model
... )
>>> # Launch the Pipeline
>>> pipeline(
... train_data=your_train_data, # must be torch.Tensor or np.array
... eval_data=your_eval_data # must be torch.Tensor or np.array
... )
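The pipeline above expects train_data and eval_data as torch.Tensor or np.array. As a minimal sketch of one way to prepare such arrays, the snippet below assumes raw grayscale images stored as uint8 values of shape (N, 28, 28). The helper name prepare_images, the scaling to [0, 1] and the 80/20 split are illustrative assumptions, not requirements stated by the library.

```python
import numpy as np


def prepare_images(images: np.ndarray, eval_fraction: float = 0.2, seed: int = 0):
    """Hypothetical helper: scale uint8 images, add a channel axis, split train/eval."""
    # Assumed preprocessing: map pixel values from [0, 255] to [0, 1].
    data = images.astype(np.float32) / 255.0
    # (N, 28, 28) -> (N, 1, 28, 28), matching input_dim=(1, 28, 28) in the model config.
    data = data[:, np.newaxis, :, :]
    # Shuffle once with a fixed seed so the split is reproducible.
    indices = np.random.default_rng(seed).permutation(len(data))
    n_eval = int(len(data) * eval_fraction)
    return data[indices[n_eval:]], data[indices[:n_eval]]


# Random stand-in data; replace with your own images.
raw_images = np.random.default_rng(0).integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)
your_train_data, your_eval_data = prepare_images(raw_images)
```

The two returned arrays can then be passed as train_data and eval_data when calling the pipeline.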
At the end of training, the best model weights, model configuration and training configuration are stored in a final_model folder available in my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss (with my_model being the output_dir argument of the BaseTrainingConfig). If you further set the steps_saving argument to a certain value, folders named checkpoint_epoch_k containing the best model weights, optimizer, scheduler, configuration and training configuration at epoch k will also appear in my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss.
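Since each run writes to its own timestamped folder, it can be handy to look up the most recent final_model folder programmatically. The snippet below is a minimal sketch that relies only on the folder naming pattern described above; the helper name latest_final_model is hypothetical and not part of pythae.

```python
from pathlib import Path
from typing import Optional


def latest_final_model(output_dir: str) -> Optional[Path]:
    """Return the final_model folder of the most recent run in output_dir, or None."""
    # Keep only run folders that actually contain a final_model sub-folder.
    runs = [
        run for run in Path(output_dir).glob("*_training_*")
        if (run / "final_model").is_dir()
    ]
    if not runs:
        return None
    # The YYYY-MM-DD_hh-mm-ss suffix sorts chronologically as plain text,
    # so compare on the timestamp rather than on the model name prefix.
    newest = max(runs, key=lambda run: run.name.split("_training_")[-1])
    return newest / "final_model"


# Example: latest_final_model("my_model") returns None if no run has finished yet.
```

The returned path is the kind of folder that load_from_folder expects when retrieving a trained model.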
Launching a training on benchmark datasets
We also provide an example training script here that can be used to train the models on benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line
python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'
See README.md for further details on this script
Launching data generation
To launch the data generation process from a trained model, you only need to build your sampler. For instance, to generate new data with a NormalSampler, run the following.
>>> from pythae.models import VAE
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = VAE.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> my_sampler = NormalSampler(
... model=my_trained_vae
... )
>>> # Generate samples
>>> gen_data = my_sampler.sample(
... num_samples=50,
... batch_size=10,
... output_dir=None,
... return_gen=True
... )
If you set output_dir to a specific path, the generated images will be saved as .png files named 00000000.png, 00000001.png ... A sampler can be used with any model it is suited to. For instance, a GaussianMixtureSampler instance can be used to generate from any model, but a VAMPSampler will only be usable with a VAMP model. Check the Available Samplers table to see which ones apply to your model.
Define your own Autoencoder architecture
Pythae lets you define your own neural networks within the VAE models. For instance, say you want to train a Wasserstein AE with a specific encoder and decoder. You can do the following:
>>> import torch
>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> class My_Encoder(BaseEncoder):
...     def __init__(self, args=None): # args is a ModelConfig instance
...         BaseEncoder.__init__(self)
...         self.layers = my_nn_layers() # placeholder for your own layers
...
...     def forward(self, x: torch.Tensor) -> ModelOutput:
...         out = self.layers(x)
...         output = ModelOutput(
...             embedding=out # Set the output from the encoder in a ModelOutput instance
...         )
...         return output
...
>>> class My_Decoder(BaseDecoder):
...     def __init__(self, args=None):
...         BaseDecoder.__init__(self)
...         self.layers = my_nn_layers() # placeholder for your own layers
...
...     def forward(self, x: torch.Tensor) -> ModelOutput:
...         out = self.layers(x)
...         output = ModelOutput(
...             reconstruction=out # Set the output from the decoder in a ModelOutput instance
...         )
...         return output
...
>>> my_encoder = My_Encoder()
>>> my_decoder = My_Decoder()
And now build the model
>>> from pythae.models import WAE_MMD, WAE_MMD_Config
>>> # Set up the model configuration
>>> my_wae_config = WAE_MMD_Config(
... input_dim=(1, 28, 28),
... latent_dim=10
... )
>>> # Build the model
>>> my_wae_model = WAE_MMD(
... model_config=my_wae_config,
... encoder=my_encoder, # pass your encoder as argument when building the model
... decoder=my_decoder # pass your decoder as argument when building the model
... )
Important note 1: For all AE-based models (AE, WAE, RAE_L2, RAE_GP), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings under the key embedding. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
Important note 2: For all VAE-based models (VAE, Beta_VAE, IWAE, HVAE, VAMP, RHVAE), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings and log-covariance matrices (of shape batch_size x latent_space_dim) under the keys embedding and log_covariance respectively. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
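A quick way to catch a mismatch with these requirements before training is to run one batch through your encoder and inspect the result. The snippet below is a minimal sketch of such a check for VAE-based models. The helper check_vae_encoder_output is hypothetical (not part of pythae) and assumes only that the output supports key lookup and that each entry exposes a shape attribute, as torch tensors do.

```python
from typing import Any, Mapping


def check_vae_encoder_output(
    output: Mapping[str, Any], batch_size: int, latent_dim: int
) -> None:
    """Hypothetical sanity check for the output of a VAE-style encoder."""
    expected = (batch_size, latent_dim)
    for key in ("embedding", "log_covariance"):
        # Both keys are required for VAE-based models.
        if key not in output:
            raise KeyError(f"encoder output is missing the key '{key}'")
        # Each entry is expected to have shape batch_size x latent_space_dim.
        shape = tuple(output[key].shape)
        if shape != expected:
            raise ValueError(f"'{key}' has shape {shape}, expected {expected}")
```

For AE-based models the same idea applies with only the embedding key. The function raises on the first problem it finds and returns None when the output looks consistent.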
Using benchmark neural nets
You can also find predefined neural network architectures for the most common data sets (e.g. MNIST, CIFAR, CELEBA ...) that can be loaded as follows
>>> from pythae.models.nn.benchmark.mnist import (
... Encoder_AE_MNIST, # For AE based model (only return embeddings)
... Encoder_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
... Decoder_AE_MNIST
... )
Replace mnist by cifar or celeba to access the other neural nets.
Getting your hands on the code
To help you understand the way pythae works and how you can train your models with this library, we also provide tutorials:
- making_your_own_autoencoder.ipynb shows you how to pass your own networks to the models implemented in pythae
- models_training folder provides notebooks showing how to train each implemented model and how to sample from it using pythae.samplers
- scripts folder provides in particular an example of a training script to train the models on benchmark data sets (mnist, cifar10, celeba ...)
Dealing with issues
If you experience any issues while running the code or want to request new features/models to be implemented, please open an issue on GitHub.
Contributing 🚀
Want to contribute to this library by adding a model or a sampler, or simply by fixing a bug? That's awesome! Thank you! Please see CONTRIBUTING.md to follow the main contributing guidelines.