SMIT: Stochastic Multi-Label Image-to-image Translation
This repository provides a PyTorch implementation of SMIT. SMIT can stochastically translate an input image to multiple domains using only a single generator and a single discriminator. It only needs a target domain (a binary vector, e.g., [0,1,0,1,1] for 5 different domains) and random Gaussian noise.
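As a rough intuition (hypothetical tensor names and shapes, not the repository's actual API), a single forward pass consumes an image, a binary multi-label target vector, and a Gaussian noise code:

```python
import torch

# Illustrative shapes only; the real generator is defined in this repository.
batch, n_domains, style_dim = 4, 5, 20
image = torch.randn(batch, 3, 128, 128)                         # input images
target = torch.tensor([[0., 1., 0., 1., 1.]]).repeat(batch, 1)  # binary multi-label target
noise = torch.randn(batch, style_dim)                           # random Gaussian style code

# fake = generator(image, target, noise)  # one generator covers every domain combination
```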
Paper
SMIT: Stochastic Multi-Label Image-to-image Translation
Andrés Romero¹, Pablo Arbeláez¹, Luc Van Gool², Radu Timofte²
¹ Biomedical Computer Vision (BCV) Lab, Universidad de Los Andes.
² Computer Vision Lab (CVL), ETH Zürich.
Citation
@article{romero2019smit,
title={SMIT: Stochastic Multi-Label Image-to-Image Translation},
author={Romero, Andr{\'e}s and Arbel{\'a}ez, Pablo and Van Gool, Luc and Timofte, Radu},
journal={ICCV Workshops},
year={2019}
}
Dependencies
The code is implemented in PyTorch; the pretrained models below were trained with PyTorch 1.0. Horovod is required only for multi-GPU training (see below).
Usage
Cloning the repository
$ git clone https://github.com/BCV-Uniandes/SMIT.git
$ cd SMIT
Downloading the dataset
To download the CelebA dataset:
$ bash generate_data/download.sh
Train command:
./main.py --GPU=$gpu_id --dataset_fake=CelebA
Each dataset must have its corresponding files under datasets/. All models and figures will be stored at snapshot/models/$dataset_fake/ and snapshot/samples/$dataset_fake/, respectively.
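For orientation, here is a hypothetical sketch of what a per-dataset loader might look like (the class name and attribute-file format are assumptions; the repository's own dataset modules are the reference):

```python
import torch
from PIL import Image
from torch.utils.data import Dataset

class CelebADataset(Dataset):
    """Hypothetical loader: one line per image, e.g. '000001.jpg 0 1 0 1 1'."""

    def __init__(self, lines, transform):
        self.lines = lines
        self.transform = transform

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        path, *attrs = self.lines[idx].split()
        image = self.transform(Image.open(path).convert('RGB'))
        label = torch.tensor([float(a) for a in attrs])  # binary attribute vector
        return image, label
```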
Test command:
./main.py --GPU=$gpu_id --dataset_fake=CelebA --mode=test
SMIT expects the .pth weights to be stored at snapshot/models/$dataset_fake/ (alternatively, --pretrained_model=location/model.pth can be provided). If there are several models, it will take the last one in alphabetical order.
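The selection rule can be pictured with a small helper (a sketch, assuming checkpoints are plain .pth state dicts):

```python
import os
from glob import glob

import torch

def load_last_checkpoint(model, model_dir):
    """Load the alphabetically last .pth file in model_dir, mirroring the behavior above."""
    checkpoints = sorted(glob(os.path.join(model_dir, '*.pth')))
    if not checkpoints:
        raise FileNotFoundError('No .pth weights in ' + model_dir)
    model.load_state_dict(torch.load(checkpoints[-1], map_location='cpu'))
    return checkpoints[-1]

# load_last_checkpoint(generator, 'snapshot/models/CelebA/')
```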
Demo:
./main.py --GPU=$gpu_id --dataset_fake=CelebA --mode=test --DEMO_PATH=location/image_jpg/or/location/dir
The demo performs one transformation per attribute, that is, it swaps each attribute with respect to the original input, as in the images below. Accordingly, --DEMO_LABEL can be provided with the real attributes when DEMO_PATH is a single image (if it is not provided, the discriminator acts as a classifier for the real attributes).
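The attribute swap itself can be sketched as flipping one label bit at a time (a hypothetical helper; the demo code in this repo is authoritative):

```python
import torch

def swap_attributes(real_label):
    """Flip one attribute at a time relative to the input's real label.

    real_label: binary vector of shape (n_domains,). Returns one target per attribute.
    """
    targets = real_label.repeat(real_label.size(0), 1)  # (n_domains, n_domains)
    idx = torch.arange(real_label.size(0))
    targets[idx, idx] = 1.0 - targets[idx, idx]         # invert the diagonal entry
    return targets

# Example: label [0, 1, 0] -> targets [[1, 1, 0], [0, 0, 0], [0, 1, 1]]
print(swap_attributes(torch.tensor([0., 1., 0.])))
```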
Pretrained models
Models trained using PyTorch 1.0.
Multi-GPU
For multi-GPU training we use Horovod. Example for training with 4 GPUs:
mpirun -n 4 ./main.py --dataset_fake=CelebA
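For reference, a generic Horovod training setup in PyTorch looks like the sketch below (standard Horovod calls with a stand-in model, not SMIT's actual training loop):

```python
import horovod.torch as hvd
import torch
import torch.nn as nn

hvd.init()                               # one process per GPU when launched via mpirun
torch.cuda.set_device(hvd.local_rank())  # pin each process to its own GPU

model = nn.Linear(10, 5).cuda()          # stand-in module, not the SMIT generator
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4 * hvd.size())

# Average gradients across workers and start all workers from identical state.
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
```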
Qualitative Results. Multi-Domain Continuous Interpolation.
First column (original input) -> Last column (opposite attributes: smile, age, gender, sunglasses, bangs, hair color). Top: continuous interpolation for the fake image. Bottom: continuous interpolation for the attention mechanism.
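The continuous interpolation amounts to blending two Gaussian style codes before each generator call; a minimal sketch of the idea (hypothetical generator call):

```python
import torch

z0, z1 = torch.randn(20), torch.randn(20)    # two random Gaussian style codes
for t in torch.linspace(0.0, 1.0, steps=8):
    z = (1.0 - t) * z0 + t * z1              # linear interpolation in noise space
    # fake = generator(image, target, z.unsqueeze(0))  # one frame of the interpolation
```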