PyTorch implementation of "Learning to Discover Cross-Domain Relations with Generative Adversarial Networks"

Overview

DiscoGAN in PyTorch

* All samples in README.md are generated by the neural network, except for the first image in each row.
* The network structure is slightly different (here) from the author's code.

Requirements

Usage

First, download a dataset (from pix2pix) with:

$ bash ./data/download_dataset.sh dataset_name

Or use your own dataset by arranging the images like this:

data
├── YOUR_DATASET_NAME
│   ├── A
│   │   ├── xxx.jpg (name doesn't matter)
│   │   ├── yyy.jpg
│   │   └── ...
│   └── B
│       ├── zzz.jpg
│       ├── www.jpg
│       └── ...
└── download_dataset.sh
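A quick way to confirm that a custom dataset matches this layout is sketched below. `check_dataset_layout` is a hypothetical helper (not part of the repository) that uses only the standard library; the set of accepted image extensions is an assumption.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust for your data).
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def check_dataset_layout(root, name):
    """Hypothetical helper: count images in <root>/<name>/A and <root>/<name>/B.

    Raises FileNotFoundError if a domain folder is missing and
    ValueError if a domain folder contains no images.
    """
    counts = {}
    for domain in ("A", "B"):
        folder = Path(root) / name / domain
        if not folder.is_dir():
            raise FileNotFoundError(f"missing folder: {folder}")
        # Only regular files with an image extension are counted.
        images = [p for p in folder.iterdir()
                  if p.is_file() and p.suffix.lower() in IMAGE_EXTS]
        if not images:
            raise ValueError(f"no images found in {folder}")
        counts[domain] = len(images)
    return counts
```

Calling `check_dataset_layout("data", "YOUR_DATASET_NAME")` returns the number of images found in each domain, or raises if the layout is wrong.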

All images in each dataset should have the same size. You can resize them with ImageMagick, for example:

# for Ubuntu
$ sudo apt-get install imagemagick
$ mogrify -resize 256x256! -quality 100 -path YOUR_DATASET_NAME/A YOUR_DATASET_NAME/A/*.jpg
$ mogrify -resize 256x256! -quality 100 -path YOUR_DATASET_NAME/B YOUR_DATASET_NAME/B/*.jpg

# for Mac
$ brew install imagemagick
$ mogrify -resize 256x256! -quality 100 -path YOUR_DATASET_NAME/A YOUR_DATASET_NAME/A/*.jpg
$ mogrify -resize 256x256! -quality 100 -path YOUR_DATASET_NAME/B YOUR_DATASET_NAME/B/*.jpg

# for scale and center crop
$ mogrify -resize 256x256^ -gravity center -crop 256x256+0+0 -quality 100 -path ../A ../A/*.jpg
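For reference, `-resize 256x256^` scales the image so that its shorter side becomes 256 pixels while keeping the aspect ratio, and `-gravity center -crop 256x256+0+0` then keeps the central 256x256 region. The arithmetic can be sketched as follows; `fill_and_center_crop_box` is a hypothetical helper, and ImageMagick's own rounding may differ by a pixel.

```python
def fill_and_center_crop_box(width, height, size=256):
    """Return (resized_w, resized_h, left, top) for a fill-resize plus center crop.

    The image is scaled so its shorter side equals `size` (aspect ratio kept);
    the size x size crop window then starts at (left, top).
    """
    scale = size / min(width, height)
    # max() guards against a float result like 255.99999 rounding below `size`.
    resized_w = max(size, round(width * scale))
    resized_h = max(size, round(height * scale))
    left = (resized_w - size) // 2
    top = (resized_h - size) // 2
    return resized_w, resized_h, left, top
```

For a 640x480 photo this gives a 341x256 intermediate image cropped from x = 42. With Pillow, `ImageOps.fit(image, (256, 256))` performs a comparable scale-and-center-crop in one call.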

To train a model:

$ python main.py --dataset=edges2shoes --num_gpu=1
$ python main.py --dataset=YOUR_DATASET_NAME --num_gpu=4

To test a model (set --load_path to your own log directory):

$ python main.py --dataset=edges2handbags --load_path=logs/edges2handbags_2017-03-18_10-55-37 --num_gpu=0 --is_train=False
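Log directories follow the pattern `logs/<dataset>_<YYYY-MM-DD_HH-MM-SS>`, as in the example above. A minimal sketch for picking the newest one to pass as `--load_path`; `latest_log_dir` is a hypothetical helper, not part of the repository, and it assumes that naming pattern holds.

```python
import re
from pathlib import Path

# Matches names like "edges2handbags_2017-03-18_10-55-37".
_RUN_DIR = re.compile(
    r"^(?P<dataset>.+)_(?P<stamp>\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})$")


def latest_log_dir(dataset, logs_root="logs"):
    """Hypothetical helper: newest <logs_root>/<dataset>_<timestamp> dir, or None."""
    root = Path(logs_root)
    if not root.is_dir():
        return None
    candidates = []
    for path in root.iterdir():
        match = _RUN_DIR.match(path.name)
        if path.is_dir() and match and match.group("dataset") == dataset:
            candidates.append((match.group("stamp"), path))
    if not candidates:
        return None
    # Zero-padded timestamps sort chronologically as plain strings.
    return max(candidates)[1]
```

The returned path can then be used as `--load_path=...` together with `--is_train=False`.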

Results

1. Toy dataset

Samples from a 2-dimensional Gaussian mixture model (see the IPython notebook).
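The notebook's exact mixture parameters are not given here. As an assumption, a common toy setup places equally weighted Gaussian modes on a circle; the sketch below samples such data with the standard library (`sample_gmm_2d` and its defaults are hypothetical).

```python
import math
import random


def sample_gmm_2d(n, num_modes=8, radius=2.0, std=0.05, seed=None):
    """Draw n points from a 2-D Gaussian mixture with equal-weight modes on a circle."""
    rng = random.Random(seed)
    centers = [(radius * math.cos(2 * math.pi * k / num_modes),
                radius * math.sin(2 * math.pi * k / num_modes))
               for k in range(num_modes)]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)  # pick a mixture component uniformly
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points
```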

# iteration: 0:

# iteration: 10000:

2. Shoes2handbags dataset

# iteration: 11200:

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (shoe -> handbag -> shoe)

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (handbag -> shoe -> handbag)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) -> G_AB(G_BA(G_AB(x_A))) -> G_BA(G_AB(G_BA(G_AB(x_A)))) -> ...
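The long chain in the last row simply alternates the two generators. A small sketch with the generators passed in as plain callables; in the real model they would be the trained `G_AB` and `G_BA` networks, and `alternate_translations` is a hypothetical helper.

```python
def alternate_translations(x, g_first, g_second, steps):
    """Return [x, g_first(x), g_second(g_first(x)), ...] after `steps` translations."""
    outputs = [x]
    for i in range(steps):
        # Even steps apply g_first, odd steps apply g_second.
        generator = g_first if i % 2 == 0 else g_second
        outputs.append(generator(outputs[-1]))
    return outputs
```

Passing string-building lambdas reproduces the notation above: `alternate_translations("x_A", lambda s: f"G_AB({s})", lambda s: f"G_BA({s})", 2)` gives `["x_A", "G_AB(x_A)", "G_BA(G_AB(x_A))"]`. For the x_B rows, pass `G_BA` first.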

3. Edges2shoes dataset

# iteration: 9600:

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (color -> sketch -> color)

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (sketch -> color -> sketch)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) -> G_AB(G_BA(G_AB(x_A))) -> G_BA(G_AB(G_BA(G_AB(x_A)))) -> ...

4. Edges2handbags dataset

# iteration: 9500:

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (color -> sketch -> color)

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (sketch -> color -> sketch)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) -> G_AB(G_BA(G_AB(x_A))) -> G_BA(G_AB(G_BA(G_AB(x_A)))) -> ...

5. Cityscapes dataset

# iteration: 8350:

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (image -> segmentation -> image)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (segmentation -> image -> segmentation)

6. Map dataset

# iteration: 22200:

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (image -> segmentation -> image)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (segmentation -> image -> segmentation)

7. Facades dataset

Generation and reconstruction on this dense segmentation dataset look odd, and results for such datasets are not included in the paper.
I suspect the naive choice of mean squared error as the reconstruction loss needs to change for this dataset.
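Which loss would work better is not stated here. One candidate, offered only as an assumption, is an L1 (mean absolute error) reconstruction loss, which penalizes a few large per-pixel errors less severely than MSE does. The difference is easy to see on toy four-pixel "images":

```python
def mse(pred, target):
    """Mean squared error: each per-pixel error is penalized quadratically."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred, target):
    """Mean absolute error (L1): each per-pixel error is penalized linearly."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# One large error versus four small ones, against an all-zero target.
target = [0.0, 0.0, 0.0, 0.0]
one_large_error = [0.0, 0.0, 0.0, 2.0]
many_small_errors = [0.5, 0.5, 0.5, 0.5]
```

Both reconstructions cost 0.5 under L1, while MSE rates the single large error four times worse (1.0 vs. 0.25). In PyTorch the corresponding modules are `nn.MSELoss` and `nn.L1Loss`.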

# iteration: 19450:

x_B -> G_BA(x_B) -> G_AB(G_BA(x_B)) (image -> segmentation -> image)

x_A -> G_AB(x_A) -> G_BA(G_AB(x_A)) (segmentation -> image -> segmentation)

Related works

Author

Taehoon Kim / @carpedm20

Comments
  • Inconsistent tensor sizes error with own data

    I tried out my own dataset in data/mydata (with the A and B folders) but I get the following error:

    ~/DiscoGAN-pytorch$ python main.py --dataset=mydata --num_gpu=1
    [*] MODEL dir: logs/mydata_2017-03-21_14-49-25
    [*] PARAM path: logs/mydata_2017-03-21_14-49-25/params.json
    Traceback (most recent call last):
      File "main.py", line 41, in <module>
        main(config)
      File "main.py", line 33, in main
        trainer.train()
      File "/home/bart/DiscoGAN-pytorch/trainer.py", line 161, in train
        valid_x_A, valid_x_B = self._get_variable(A_loader.next()), self._get_variable(B_loader.next())
      File "/home/bart/anaconda3/envs/Python36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 174, in __next__
        return self._process_next_batch(batch)
      File "/home/bart/anaconda3/envs/Python36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 198, in _process_next_batch
        raise batch.exc_type(batch.exc_msg)
    RuntimeError: Traceback (most recent call last):
      File "/home/bart/anaconda3/envs/Python36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 32, in _worker_loop
        samples = collate_fn([dataset[i] for i in batch_indices])
      File "/home/bart/anaconda3/envs/Python36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 68, in default_collate
        return torch.stack(batch, 0)
      File "/home/bart/anaconda3/envs/Python36/lib/python3.6/site-packages/torch/functional.py", line 56, in stack
        return torch.cat(list(t.unsqueeze(dim) for t in sequence), dim)
    RuntimeError: inconsistent tensor sizes at /data/users/soumith/miniconda2/conda-bld/pytorch-0.1.10_1488755368782/work/torch/lib/TH/generic/THTensorMath.c:2548
    
    opened by bartolsthoorn 9
  • Multi GPU

    Thanks for the PyTorch implementation of DiscoGAN. I am having trouble running it with multiple GPUs, though. I think the first problem is in the config.py file, where the num_gpu argument is cast to a bool instead of an int. https://github.com/carpedm20/DiscoGAN-pytorch/blob/2feae821dcea201b5461db58e2438af97f11cb63/config.py#L53 But even after fixing that, the code only runs with --num_gpu=1, not with --num_gpu=2. Here is the traceback:

    Error:

    Traceback (most recent call last):
      File "main.py", line 41, in <module>
        main(config)
      File "main.py", line 33, in main
        trainer.train()
      File "/home/***/Documents/DiscoGAN-pytorch-master/trainer.py", line 187, in train
        x_AB = self.G_AB(x_A).detach()
      File "/home/***/.pyenv/versions/2.7.13/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/***/Documents/DiscoGAN-pytorch-master/models.py", line 45, in forward
        return nn.parallel.data_parallel(self.main, x, gpu_ids)
      File "/home/***/.pyenv/versions/2.7.13/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 101, in data_parallel
        replicas = replicate(module, device_ids[:len(inputs)])
      File "/home/***/.pyenv/versions/2.7.13/lib/python2.7/site-packages/torch/nn/parallel/replicate.py", line 10, in replicate
        params = list(network.parameters())
    AttributeError: 'function' object has no attribute 'parameters'
    

    Any idea what could be the cause?

    bug help wanted 
    opened by mehdi-shiba 4
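The issue points to `num_gpu` being parsed as a bool. A common cause is `type=bool` in argparse, which turns any non-empty string (including "False") into True. A hedged sketch of the usual fix, using the flag names from the commands above with assumed defaults; this is not the repository's actual config.py, and `str2bool` is a hypothetical helper.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings (argparse's type=bool would map "False" to True)."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser()
    # An int, so values such as 2 or 4 survive parsing.
    parser.add_argument("--num_gpu", type=int, default=1)
    # A proper boolean parser, so --is_train=False really is False.
    parser.add_argument("--is_train", type=str2bool, default=True)
    return parser
```

With this, `--num_gpu=2 --is_train=False` parses to `num_gpu == 2` and `is_train is False`. The separate `'function' object has no attribute 'parameters'` error in the traceback is not addressed by this sketch.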
  • Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 169])) is deprecated. Please ensure they have the same size.

    (base) D:\DiscoGAN>python main.py --dataset=siys2simk --input_scale_size 256 --batch_size 4 --a_grayscale True --b_grayscale True --num_worker 1 --num_gpu=0
    [*] MODEL dir: ./logs\siys2simk_2022-08-26_11-17-52
    [*] PARAM path: ./logs\siys2simk_2022-08-26_11-17-52\params.json
      0%|          | 0/500000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "D:\DiscoGAN\main.py", line 41, in <module>
        main(config)
      File "D:\DiscoGAN\main.py", line 33, in main
        trainer.train()
      File "D:\DiscoGAN\trainer.py", line 200, in train
        l_d_A_real, l_d_A_fake = bce(self.D_A(x_A).squeeze(1), real_tensor), bce(self.D_A(x_BA).squeeze(1), fake_tensor)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 613, in forward
        return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 3074, in binary_cross_entropy
        raise ValueError(
    ValueError: Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 169])) is deprecated. Please ensure they have the same size.
    
    opened by typeface-cn 0
  • How to start training??

    Hello Sir,

    I downloaded your code and the maps dataset, but I got an error when I started training:

    ...
    Traceback (most recent call last):
      File "main.py", line 41, in <module>
        main(config)
      File "main.py", line 33, in main
        trainer.train()
      File "/itsme/TESTBOARD/additional_networks/GAN/pytorch_DiscoGAN_carpedm20/trainer.py", line 247, in train
        format(step, self.max_step, l_d.data[0], l_g.data[0]))
    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    

    Thanks.

    opened by edwardcho 1
  • It seems that the discriminator cannot adapt to the size of the input image.

    Whenever I set '--input_scale_size' to anything other than 64, an error is reported: "ValueError: Target and input must have the same number of elements. target nelement (2) != input nelement (338)". It seems that the discriminator cannot adapt to the size of the input image. How did you solve this problem?

    opened by SwordHolderSH 1
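The element counts in these errors are consistent with a discriminator that outputs a spatial map rather than a single score. Assuming four 4x4 stride-2 convolutions (padding 1) followed by a final 4x4 stride-1 convolution without padding (an assumption about the model, not confirmed on this page), a 64x64 input yields one score per image, while a 256x256 input yields a 13x13 map, i.e. 169 values. That matches `torch.Size([4, 169])` with a batch of 4, and 338 = 2 x 169 with a batch of 2.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a 2-D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_output_elements(input_size, num_downsamples=4):
    """Values per image from an ASSUMED DCGAN-style discriminator.

    Assumed architecture: `num_downsamples` 4x4 convs with stride 2 and
    padding 1, then a final 4x4 conv with stride 1 and no padding.
    """
    size = input_size
    for _ in range(num_downsamples):
        size = conv_out(size, kernel=4, stride=2, padding=1)  # halves the size
    size = conv_out(size, kernel=4, stride=1, padding=0)      # 4x4 -> 1x1 at 64 input
    return size * size
```

Under this assumption, possible fixes are to keep `--input_scale_size=64`, to add downsampling layers for larger inputs, or to build the real/fake label tensors with the same shape as the discriminator output (for example with `torch.ones_like` and `torch.zeros_like`).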
  • torch size

    I'm getting this message:

    0%| | 1/50000 [00:03<55:08:42, 3.97s/it]C:\Users\Shadow\Anaconda3\lib\site-packages\torch\nn\modules\loss.py:512: UserWarning: Using a target size (torch.Size([200])) that is different to the input size (torch.Size([200, 1])) is deprecated. Please ensure they have the same size.

    opened by jhaseon 1
  • DiscoGAN paper said that they don't need paired data unlike Conditional GAN. but...

    In the DiscoGAN paper, the authors assert that "to avoid costly pairing, we address the task of discovering cross-domain relations given unpaired data". So why does this implementation of DiscoGAN require paired data (as far as I can tell from the downloaded datasets)? Or is it just to simplify the input process?

    opened by timothy-jhang 0
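The traceback in the first issue above shows separate `A_loader` and `B_loader` objects. If the two domains are shuffled independently, image i in A is not tied to image i in B, so any pairing in the downloaded pix2pix data goes unused. Whether this repository shuffles them exactly this way is an assumption here; the sketch below, with a hypothetical `unpaired_batches` helper, only illustrates the idea.

```python
import random


def unpaired_batches(files_a, files_b, batch_size, seed=None):
    """Yield (batch_a, batch_b) with A and B shuffled independently.

    Because each list gets its own shuffle, index i in A is not tied to
    index i in B. The last partial batch is dropped.
    """
    rng = random.Random(seed)
    a, b = list(files_a), list(files_b)
    rng.shuffle(a)
    rng.shuffle(b)
    usable = min(len(a), len(b))
    for start in range(0, usable - batch_size + 1, batch_size):
        yield a[start:start + batch_size], b[start:start + batch_size]
```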