PyTorch implementation of VQ-VAE.
This paper combines 2 tricks:
- Vector Quantization (check out this amazing blog for a better understanding).
- Straight-Through estimator (it solves the problem of back-propagating through discrete latent variables, which would otherwise be intractable; see the sketch below).
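The straight-through trick itself is only a couple of lines of PyTorch. The following is a minimal sketch, not the exact code of this repo; the variable names and shapes are illustrative assumptions:

```python
import torch

# z_e: encoder output, z_q: its nearest codebook embedding (both [B, D] here).
# The nearest-neighbour lookup is non-differentiable, so gradients cannot
# flow from the decoder back to the encoder through z_q directly.
z_e = torch.randn(8, 64, requires_grad=True)
z_q = torch.randn(8, 64)  # stands in for the quantized vectors

# Straight-through estimator: the forward pass uses z_q, while the backward
# pass copies the decoder's gradient straight onto the encoder output z_e.
z_q_st = z_e + (z_q - z_e).detach()
```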
This model has a neural-network encoder and decoder, and a prior, just like the vanilla Variational AutoEncoder (VAE). In addition, it has a latent embedding space called the codebook
(size: K x D), where K is the number of embeddings (the size of the discrete latent space) and D is the dimension of each embedding e.
In a vanilla VAE, the encoder output z(x) parameterizes a Normal/Gaussian distribution, which is sampled (via the 'reparameterization trick') to obtain a latent representation z of the input x; this latent is then passed to the decoder. In a VQ-VAE, however, the encoder output ze(x) is used as a "key" for a nearest-neighbour lookup into the embedding codebook, yielding zq(x), the closest embedding in that space. This is the Vector Quantization (VQ)
operation. zq(x) is then passed to the decoder, which reconstructs the input x. The decoder can either parameterize p(x|z) as the mean of a Normal distribution using transposed convolution layers, as in a vanilla VAE, or autoregressively generate a categorical distribution over the [0, 255] pixel values, as in PixelCNN. This project uses the first approach.
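A minimal sketch of the VQ operation (nearest-neighbour lookup into a K x D codebook) is shown below. The name `quantize`, the values of K and D, and the tensor shapes are illustrative assumptions, not the exact code of this repo:

```python
import torch
import torch.nn as nn

K, D = 512, 64                  # codebook size and embedding dimension (illustrative)
codebook = nn.Embedding(K, D)   # the latent embedding space e

def quantize(z_e):
    """Map each encoder output vector to its nearest codebook embedding.

    z_e: [N, D] encoder outputs (flattened over batch/spatial dims).
    Returns z_q: [N, D] quantized vectors, plus the chosen indices.
    """
    # Euclidean distances between every ze(x) and every embedding e_k.
    distances = torch.cdist(z_e, codebook.weight)   # [N, K]
    indices = distances.argmin(dim=1)               # nearest-neighbour index per vector
    z_q = codebook(indices)                         # [N, D]
    return z_q, indices

# Example: 16 encoder vectors of dimension D.
z_e = torch.randn(16, D)
z_q, idx = quantize(z_e)
```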
The loss function consists of 3 components:
- Regular reconstruction loss
- Vector Quantization loss
- Commitment loss
The Vector Quantization loss, ||sg[ze(x)] - e||^2, encourages the codebook embeddings to move closer to the encoder output,
while the Commitment loss, ||ze(x) - sg[e]||^2, encourages the encoder output to stay close to the embedding it picked, i.e. to commit to its codebook embedding.
The Commitment loss is multiplied by a constant beta, which is 1.0 in this project. Here, sg means "stop-gradient": gradients are not propagated with respect to that term.
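Putting the three terms together, the training loss might look like the sketch below. The function name, `recon`, `z_e`, `z_q`, and `beta` are illustrative assumptions, and the reconstruction term assumes the Gaussian/MSE decoder described above:

```python
import torch.nn.functional as F

def vqvae_loss(x, recon, z_e, z_q, beta=1.0):
    # Reconstruction loss: the decoder parameterizes the mean of a Normal, so MSE.
    recon_loss = F.mse_loss(recon, x)
    # Vector Quantization loss: ||sg[ze(x)] - e||^2, moves embeddings toward encoder outputs.
    vq_loss = F.mse_loss(z_e.detach(), z_q)
    # Commitment loss: ||ze(x) - sg[e]||^2, keeps the encoder committed to its embedding.
    commit_loss = F.mse_loss(z_e, z_q.detach())
    return recon_loss + vq_loss + beta * commit_loss
```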
Results:
The model is trained on the MNIST and CIFAR10 datasets.
Target 👉 Reconstructed Image

Details:
- Trained models for MNIST and CIFAR10 are in the `Trained models` directory.
- Hidden size of the bottleneck (z) is 128 for MNIST and 256 for CIFAR10.