Contrastively Disentangled Sequential Variational Autoencoder (C-DSVAE)
Overview
This is the implementation of our C-DSVAE, a self-supervised method for disentangled sequential representation learning.
Requirements
- Python 3
- PyTorch 1.7
- Numpy 1.18.5
Dataset
Sprites
We provide the raw Sprites .npy files. The dataset can also be found in a third-party repo.
For each split (train/test), we expect the following components for each sequence sample (a minimal loading sketch follows this list):
- x: raw sample of shape [8, 3, 64, 64]
- c_aug: content augmentation of shape [8, 3, 64, 64]
- m_aug: motion augmentation of shape [8, 3, 64, 64]
- motion factors: action (3 classes), direction (3 classes)
- content factors: skin, tops, pants, hair (each with 6 classes)
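The loader in this repo may be organized differently; the following is a minimal sketch of how one split could be wrapped as a PyTorch Dataset, assuming each .npy file stores per-sequence dicts with keys x, c_aug, and m_aug of the shapes listed above (the file layout and the class name SpritesSequenceDataset are assumptions, not the repo's actual interface).

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class SpritesSequenceDataset(Dataset):
    """Illustrative wrapper; the .npy layout and key names are assumptions."""

    def __init__(self, npy_path):
        # Assumed layout: an array/list of dicts, each with keys
        # 'x', 'c_aug', 'm_aug' of shape [8, 3, 64, 64].
        self.data = np.load(npy_path, allow_pickle=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        x = torch.as_tensor(item['x'], dtype=torch.float32)          # raw sequence
        c_aug = torch.as_tensor(item['c_aug'], dtype=torch.float32)  # content augmentation
        m_aug = torch.as_tensor(item['m_aug'], dtype=torch.float32)  # motion augmentation
        return x, c_aug, m_aug
```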
Running
Train
./run_cdsvae.sh
Test
./run_test_sprite.sh
Classification Judge
The judge classifiers are pretrained separately with full supervision.
- Sprites judge
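The judge is used at evaluation time to predict the motion factors (action, direction) and content factors (skin, tops, pants, hair) from generated or swapped sequences. A hedged usage sketch follows; the checkpoint file name, the input layout, and the dict of per-factor logits are hypothetical, not the repo's actual interface.

```python
import torch

# Hypothetical judge checkpoint and output format, for illustration only.
judge = torch.load('sprites_judge.pth', map_location='cpu')
judge.eval()

x = torch.rand(16, 8, 3, 64, 64)  # a batch of 16 generated sequences
with torch.no_grad():
    logits = judge(x)  # assumed: {'action': [16, 3], 'direction': [16, 3], 'skin': [16, 6], ...}
    action_pred = logits['action'].argmax(dim=-1)  # predicted action class per sequence
```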
C-DSVAE Checkpoints
We provide a sample Sprites checkpoint. The checkpoint parameters can be found in ./run_test_sprite.sh.
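If you want to inspect the provided checkpoint outside of the test script, the standard PyTorch loading pattern applies. Below is a minimal sketch assuming the checkpoint was written with torch.save; the file name is a placeholder and the exact keys depend on how it was saved.

```python
import torch

# Placeholder file name; point this at the provided Sprites checkpoint.
ckpt = torch.load('cdsvae_sprites.pth', map_location='cpu')

# The checkpoint may be a raw state_dict or a dict wrapping one;
# listing the top-level keys shows which.
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])
```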
Paper
If you are inspired by our work, please cite the following paper:
@inproceedings{bai2021contrastively,
title={Contrastively Disentangled Sequential Variational Autoencoder},
author={Bai, Junwen and Wang, Weiran and Gomes, Carla},
booktitle={Advances in Neural Information Processing Systems},
volume={},
year={2021}
}