This is the official source code for SLATE. We provide the code for the model, the training code, and a dataset loader for the 3D Shapes dataset. This code is implemented in PyTorch.
The current release provides boilerplate code to train the model on the 3D Shapes dataset. The dataset class is provided in `shapes_3d.py`; you can edit or replace this class to run the code on a different dataset. The 3D Shapes dataset can be downloaded from the official URL https://console.cloud.google.com/storage/browser/3d-shapes, which provides the dataset file `3dshapes.h5`. During training, the path to this dataset file must be provided through the corresponding training argument (listed in `train.py`).
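To adapt the loader to another dataset, the class only needs to behave like a PyTorch map-style dataset, i.e. expose `__len__` and `__getitem__`. The sketch below is a hypothetical stand-in (`ImageArrayDataset` is not a class from this repository) that wraps an `(N, H, W, C)` uint8 image array; scaling pixels to `[0, 1]` and returning channels-first arrays are assumptions, not necessarily what `shapes_3d.py` does.

```python
import numpy as np


class ImageArrayDataset:
    """Hypothetical map-style dataset over an (N, H, W, C) uint8 image array.

    Any object with __len__ and __getitem__ can be passed to
    torch.utils.data.DataLoader; this is not the class in shapes_3d.py.
    """

    def __init__(self, images):
        # `images` may be a NumPy array or an h5py dataset (e.g. f["images"]).
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Assumption: scale uint8 pixels to float32 values in [0, 1].
        img = np.asarray(self.images[idx], dtype=np.float32) / 255.0
        # HWC -> CHW, the layout PyTorch convolution layers expect.
        return np.ascontiguousarray(np.transpose(img, (2, 0, 1)))
```

With h5py installed, the 3D Shapes images could be wrapped as `ImageArrayDataset(h5py.File("3dshapes.h5", "r")["images"])`; reading from the h5py dataset lazily avoids loading all images into memory at once.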
To train the model, run the training script `train.py`. Check `train.py` to see the full list of training arguments.
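Training arguments are passed on the command line. The sketch below shows how such flags might be declared with Python's `argparse`; it is not the parser in `train.py`. The names `--log_path`, `--lr_main`, `--num_slots` and `--num_iterations` come from this README, while `--data_path` is a hypothetical name and every default value is a placeholder.

```python
import argparse


def build_parser():
    """Hypothetical sketch of a training-argument parser (not train.py's)."""
    parser = argparse.ArgumentParser(description="Train SLATE (sketch)")
    # Hypothetical flag name: check train.py for the real dataset-path argument.
    parser.add_argument("--data_path", default="3dshapes.h5")
    # Flag names below appear in this README; the defaults are placeholders.
    parser.add_argument("--log_path", default="logs/")
    parser.add_argument("--lr_main", type=float, default=1e-4)
    parser.add_argument("--num_slots", type=int, default=3)
    parser.add_argument("--num_iterations", type=int, default=3)
    return parser


# Parse an explicit list instead of sys.argv to show how values are typed.
args = build_parser().parse_args(["--num_slots", "6", "--num_iterations", "2"])
print(args.num_slots, args.num_iterations)  # 6 2
```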
The training code produces TensorBoard logs. To see these logs, run TensorBoard on the logging directory that was provided in the training argument `--log_path`. These logs contain the training loss curves and visualizations of reconstructions and object attention maps.
Hyperparameters of Interest
- Learning Rate can be tuned using the training argument `--lr_main`, and different choices can affect the characteristics of the object attention maps.
- Number of Slots can be tuned using the training argument `--num_slots`. The number of slots should be set higher than the number of objects you expect to see in the images.
- Number of Slot Attention Iterations can be tuned using the training argument `--num_iterations`. In general, keep the number of iterations as small as possible, because too many iterations can prevent slots from learning to diversify and attach to different objects.
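To show where the number of slots and the number of iterations enter the computation, here is a simplified NumPy sketch of the Slot Attention loop. It is an illustration, not the code in `slot_attn.py`: random projections stand in for learned key/query/value weights, and the original algorithm additionally applies layer normalization and updates slots with a GRU and an MLP, which this sketch replaces with a plain weighted mean.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def slot_attention(inputs, num_slots, num_iterations, rng, eps=1e-8):
    """Simplified Slot Attention over `inputs` of shape (N, D).

    Returns slots of shape (num_slots, D) and the attention map from the
    last iteration, of shape (N, num_slots).
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    _, d = inputs.shape
    # Random projections stand in for the learned key/query/value weights.
    w_k, w_q, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    k, v = inputs @ w_k, inputs @ w_v
    # Slots are initialized randomly (from a learned distribution in the real model).
    slots = rng.standard_normal((num_slots, d))
    for _ in range(num_iterations):
        q = slots @ w_q
        logits = k @ q.T / np.sqrt(d)  # (N, num_slots)
        # Softmax over the SLOT axis: slots compete to explain each input.
        attn = softmax(logits, axis=1)
        # Normalize over inputs so each slot takes a weighted mean of values.
        weights = attn / (attn.sum(axis=0, keepdims=True) + eps)
        # Simplification: the real model feeds this into a GRU + MLP update.
        slots = weights.T @ v  # (num_slots, D)
    return slots, attn
```

Because the softmax is taken over slots rather than inputs, each input feature distributes its attention across slots, which is what lets different slots bind to different objects; each extra iteration repeats this competition with the updated slots.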
This repository provides the following files.
- `train.py` contains the main code for running the training.
- `slate.py` provides the model class for SLATE.
- `shapes_3d.py` contains the dataset class for the 3D Shapes dataset.
- `dvae.py` provides the encoder and the decoder for the Discrete VAE.
- `slot_attn.py` provides the model class for the Slot Attention encoder.
- `transformer.py` provides the model classes for the Transformer.
- `utils.py` provides helper classes and functions for the implementation.