This is the PyTorch implementation of GANs N’ Roses: Stable, Controllable, Diverse Image to Image Translation

Overview

GANs N' Roses Pytorch

This is the PyTorch implementation of GANs N’ Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!).

Abstract:
We show how to learn a map that takes a content code, derived from a face image, and a randomly chosen style code to an anime image. We derive an adversarial loss from our simple and effective definitions of style and content. This adversarial loss guarantees the map is diverse -- a very wide range of anime can be produced from a single content code. Under plausible assumptions, the map is not just diverse, but also correctly represents the probability of an anime, conditioned on an input face. In contrast, current multimodal generation procedures cannot capture the complex styles that appear in anime. Extensive quantitative experiments support the idea the map is correct. Extensive qualitative results show that the method can generate a much more diverse range of styles than SOTA comparisons. Finally, we show that our formalization of content and style allows us to perform video to video translation without ever training on videos.

New Gradio Web Demo

Dependency

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install tqdm gdown kornia scipy opencv-python dlib moviepy lpips aubio ninja

Dataset

The dataset we use for training is the selfie2anime dataset from UGATIT. You can also use your own dataset in the following format.

└── YOUR_DATASET_NAME
   ├── trainA
       ├── xxx.jpg (name, format doesn't matter)
       ├── yyy.png
       └── ...
   ├── trainB
       ├── zzz.jpg
       ├── www.png
       └── ...
   ├── testA
       ├── aaa.jpg 
       ├── bbb.png
       └── ...
   └── testB
       ├── ccc.jpg 
       ├── ddd.png
       └── ...

Training

For training, you may want to switch to the train branch, which uses custom CUDA kernels. Otherwise, the native PyTorch implementation is used.

python train.py --name EXP_NAME --d_path YOUR_DATASET_NAME --batch BATCH_SIZE

The full model checkpoint is here if you wish to use it for fine-tuning, etc.

Inference

Our notebook provides a comprehensive demo of both image and video translation. The pretrained model is downloaded automatically.

Citation

If you use this code or ideas from our paper, please cite our paper:

@misc{chong2021gans,
      title={GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)}, 
      author={Min Jin Chong and David Forsyth},
      year={2021},
      eprint={2106.06561},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgments

This code borrows heavily from StyleGAN2 by rosinality and partly from UGATIT.

Comments
  • What is the training performance?

    This is great work!

    I've been looking at many similar models for a business problem. Yours looks most promising, but I've noticed that information about the training performance was not included in the paper. How well does it perform, and would you consider it "efficient" on hardware?

    Thanks, Tyler

    opened by idiomic 6
  • Questions about shuffle style

    Excellent work! I would like to make it work on a mobile phone. But when I read the code (train.py), I had two questions:

    1. Why is it necessary to shuffle the original image batch?

        A = aug(ori_A[[np.random.randint(args.batch)]].expand_as(ori_A))
        B = aug(ori_B[[np.random.randint(args.batch)]].expand_as(ori_B))

    2. Won't shuffling the style lead to a mismatch? Since fake_A2B2A(c1, s1) != A(c1, s2), the cycle consistency loss may not be satisfied:

        fake_A2B2A = G_B2A.decode(A2B2A_content, shuffle_batch(A2B_style))
        fake_B2A2B = G_A2B.decode(B2A2B_content, shuffle_batch(B2A_style))
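To make the two quoted expressions concrete, here is a small sketch of what they do mechanically. It is an illustration under assumptions, not the repository's code: `repeat_one_image` and `toy_aug` are hypothetical names, `toy_aug` (a random horizontal flip) merely stands in for the real `aug`, and the body of `shuffle_batch` is a guess based on its name. The sketch shows that the indexing expression picks one image and repeats it across the batch, so after augmentation the batch holds several views of the same source image. It does not by itself settle why the authors chose this design; the paper's definitions of style and content are the place to look for that.

```python
import torch


def repeat_one_image(batch_images):
    """Pick one random image and repeat it across the whole batch."""
    idx = torch.randint(batch_images.size(0), (1,)).item()
    # Indexing with a list keeps the batch dimension: shape (1, C, H, W).
    repeated = batch_images[[idx]].expand_as(batch_images)
    return repeated, idx


def toy_aug(x):
    """Hypothetical stand-in for `aug`: flip each sample with probability 0.5."""
    flip = torch.rand(x.size(0)) < 0.5
    return torch.where(flip[:, None, None, None], x.flip(-1), x)


def shuffle_batch(x):
    """Assumed behaviour: permute the rows along the batch dimension."""
    return x[torch.randperm(x.size(0))]
```

With a batch built this way, `shuffle_batch` applied to per-sample style codes only reorders codes that all come from views of one image.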

    opened by crwsr124 3
  • about choosing latent dimension

    Hi, thanks for your great work!

    I have a question about the latent dimension setting. Is there a specific reason for setting the latent dimension to 8?

    StyleGAN's default latent dimension is 512, and in my opinion a dimension of 8 is not enough to embed style. Could you explain why you chose this value?

    opened by LeeDongYeun 2
  • About part code of the test class in Train.py

    Thanks for your great work. There is some code in the test class:

    if i % 2 == 0:
        A2B_mod1 = torch.randn([1, args.latent_dim]).cuda()
        B2A_mod1 = torch.randn([1, args.latent_dim]).cuda()
        A2B_mod2 = torch.randn([1, args.latent_dim]).cuda()
        B2A_mod2 = torch.randn([1, args.latent_dim]).cuda()
    

    I want to know the meaning of the random sampling. Why are these variables only set when i % 2 == 0?
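The mechanical effect of the `i % 2 == 0` guard can be seen in a stripped-down loop. This is only a sketch of the control flow, not the repository's test routine: `sample_codes` is a hypothetical helper, the default `latent_dim` of 8 is taken from another comment on this page, and the sketch runs on the CPU. It shows that a code drawn on an even iteration is reused on the following odd iteration, so consecutive pairs of iterations share the same random style codes.

```python
import torch


def sample_codes(num_iters, latent_dim=8):
    """Record which style code each iteration would use."""
    used = []
    code = None
    for i in range(num_iters):
        if i % 2 == 0:
            # A fresh code is drawn only on even iterations ...
            code = torch.randn([1, latent_dim])
        # ... and the odd iteration that follows reuses it.
        used.append(code)
    return used
```

Whether this pairing is meant for side-by-side comparison of two inputs under one style is a question for the authors; the sketch only demonstrates the reuse.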

    opened by diaodeyi 2
  • About D_L in the code

    Thanks for the great work. It seems the code doesn't pass D_L to train()?

    def train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device):
        G_A2B.train(), G_B2A.train(), D_A.train(), D_B.train() 
    
    opened by diaodeyi 2
  • Full model for transfer learning?

    The model released in the notebook is only 300MB and seems to only include the final A2B and B2A ema generators, so it can only be used for inference.

    Is it possible to get ahold of a full pretrained checkpoint for fine tuning/transfer learning? Retraining from scratch will take several weeks on my hardware.

    opened by arfafax 2
  • Fix README link

    Previous link to Replicate web demo in README.md file had a typo; fixed. And please claim the model here: https://replicate.com/mchong6/gans-n-roses to make it public! Thanks!

    opened by vccheng2001 1
  • Add Docker environment & web demo

    Hey @mchong6 ! 👋 This pull request uses an open source tool called Cog to make GANsNRoses more accessible to others. Our team here at Replicate has created a web demo where other people can try out your model! View it here: https://replicate.com/mchong6/gans-n-roses We've added some examples to the web demo; please claim your page here so you own it/edit it, and we'll feature it on our website and tweet about it too. For your model, you can either input an image or a video (which uses the 'normal' mode). I'm Vivian from Replicate, where we're trying to make machine learning reproducible by implementing CV/DL models we like. Let me know if you have any questions/feedback!

    opened by vccheng2001 1
  • memory consumption

    Hello, it seems that training this network takes a huge amount of memory. I first tried to train it on a 2080Ti with 12GB of memory, but it always crashed with a CUDA out-of-memory error. I then tried a V100 and it worked. Training takes about 17GB of memory; is that too much?

    opened by LeRoii 1
  • train with smaller input size mat1 dim 1 must match mat2 dim 0

    I am using a Quadro T2000 with 4 GB of memory and tried to train with batch size 1 and input size 128.

    I get the following error:

    /home/user/anaconda3/lib/python3.7/site-packages/torch/utils/cpp_extension.py:3: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
      import imp
    Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
    Loading model from: /home/user/anaconda3/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth
      0%|          | 0/300000 [00:00<?, ?it/s]
    /home/user/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:3063: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
      "See the documentation of nn.Upsample for details.".format(mode))
      0%|          | 0/300000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 465, in <module>
        train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device)
      File "train.py", line 167, in train
        A2B_content, A2B_style = G_A2B.encode(A)
      File "/home/user/mnt/develoment/code/GANsNRoses/model.py", line 501, in encode
        return self.encoder(input)
      File "/home/user/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/user/mnt/develoment/code/GANsNRoses/model.py", line 703, in forward
        style = self.style(act)
      File "/home/user/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/user/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
        input = module(input)
      File "/home/user/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/user/mnt/develoment/code/GANsNRoses/model.py", line 179, in forward
        out = F.linear(input, self.weight * self.scale)
      File "/home/user/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0
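One common cause of this kind of error is a fully connected layer whose input width was fixed for one image resolution and then fed features from another. The sketch below illustrates that failure mode with a toy module; it is an assumption about the general mechanism, not a diagnosis of this repository's encoder. `ToyStyleHead` and all of its layer sizes are hypothetical.

```python
import torch
from torch import nn


class ToyStyleHead(nn.Module):
    """Toy encoder head whose linear layer is sized for one resolution."""

    def __init__(self, size=256, channels=8, latent_dim=8):
        super().__init__()
        self.conv = nn.Conv2d(3, channels, kernel_size=3, stride=2, padding=1)
        self.pool = nn.AvgPool2d(16)
        # Spatial side after the stride-2 conv and the 16x pooling.
        feat = size // 2 // 16
        # The flattened width is baked in here, at construction time.
        self.fc = nn.Linear(channels * feat * feat, latent_dim)

    def forward(self, x):
        act = self.pool(self.conv(x))
        # Raises RuntimeError if x does not match the construction size.
        return self.fc(act.flatten(1))
```

A head built with `size=256` accepts 256x256 inputs but fails inside the linear layer on 128x128 inputs, while a head built with `size=128` accepts them. If the same mechanism applies here, the model would need to be constructed for the resolution actually used in training.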
    
    
    opened by guitar9 1
  • about pretrained mode to finetune

    Hello, I tried to fine-tune from the pretrained model that you provided, but when I pass the checkpoint to the training script, I get the following:

    2021/07/08 20:26:17 File "train.py", line 377, in
    Loading model from: /opt/conda/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth
    2021/07/08 20:26:17 G_A2B.load_state_dict(ckpt['G_A2B'])
    2021/07/08 20:26:17 Unexpected key(s) in state_dict: "encoder.stem.5.conv1.0.weight", "encoder.stem.5.conv1.1.bias", "encoder.stem.5.conv2.0.weight", "encoder.stem.5.conv2.1.bias", "encoder.stem.4.skip.0.kernel", "encoder.stem.4.skip.1.weight", "encoder.stem.4.conv2.2.bias", "encoder.stem.4.conv2.0.kernel", "encoder.stem.4.conv2.1.weight", "encoder.style.3.weight", "encoder.style.3.bias", "convs.6.conv.weight", "convs.6.conv.blur.kernel", "convs.6.conv.modulation.weight", "convs.6.conv.modulation.bias", "convs.6.activate.bias", "convs.7.conv.weight", "convs.7.conv.modulation.weight", "convs.7.conv.modulation.bias", "convs.7.activate.bias", "to_rgbs.3.bias", "to_rgbs.3.upsample.kernel", "to_rgbs.3.conv.weight", "to_rgbs.3.conv.modulation.weight", "to_rgbs.3.conv.modulation.bias".
    2021/07/08 20:26:17 size mismatch for convs.0.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 512, 3, 3]).
    2021/07/08 20:26:17 RuntimeError: Error(s) in loading state_dict for Generator:
    2021/07/08 20:26:17 size mismatch for encoder.style.4.weight: copying a param with shape torch.Size([8, 512]) from checkpoint, the shape in current model is torch.Size([512, 8192]).
    2021/07/08 20:26:17 size mismatch for convs.0.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.weight: copying a param with shape torch.Size([1, 256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 256, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.3.conv.weight: copying a param with shape torch.Size([1, 256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 128, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.2.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for convs.3.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.weight: copying a param with shape torch.Size([1, 128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.5.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for convs.3.activate.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.weight: copying a param with shape torch.Size([1, 3, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.modulation.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.weight: copying a param with shape torch.Size([1, 3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.weight: copying a param with shape torch.Size([1, 3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 64, 1, 1]).
    2021/07/08 20:26:17 size mismatch for convs.2.conv.modulation.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 category=DeprecationWarning,
    2021/07/08 20:26:17 Missing key(s) in state_dict: "encoder.stem.4.conv2.0.weight", "encoder.stem.4.conv2.1.bias", "encoder.style.2.0.kernel", "encoder.style.2.1.weight", "encoder.style.2.2.bias", "encoder.style.5.weight", "encoder.style.5.bias".
    2021/07/08 20:26:17 size mismatch for convs.3.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for convs.4.conv.weight: copying a param with shape torch.Size([1, 128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 128, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.1.conv.weight: copying a param with shape torch.Size([1, 512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 256, 3, 3]).
    2021/07/08 20:26:17 size mismatch for convs.1.activate.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.1.conv.modulation.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
    2021/07/08 20:26:17 size mismatch for to_rgbs.2.conv.modulation.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
    2021/07/08 20:26:17 File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    2021/07/08 20:26:17 size mismatch for to_rgbs.0.conv.modulation.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
    2021/07/08 20:26:17 Traceback (most recent call last):
    2021/07/08 20:26:17 size mismatch for encoder.style.4.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([512]).
    2021/07/08 20:26:17 size mismatch for convs.4.activate.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    2021/07/08 20:26:17 size mismatch for convs.5.conv.modulation.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).

    How can I fix this? Maybe with strict=False?
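On the `strict=False` question: in PyTorch, `strict=False` only relaxes the check for missing and unexpected keys; a parameter whose shape differs between checkpoint and model still raises a `RuntimeError`. The sketch below shows one way to load only the shape-compatible entries. `load_compatible` is a hypothetical helper, not part of this repository, and the parameters it skips stay at their freshly initialised values. The shapes in the log (for example 512 versus 256 channels) suggest the model may have been constructed with a different configuration than the checkpoint, so matching the constructor arguments is probably the better fix; that reading is an assumption.

```python
import torch
from torch import nn


def load_compatible(model, state_dict):
    """Load only entries whose name and shape match; return skipped names."""
    own = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in own and tensor.shape == own[name].shape
    }
    skipped = sorted(set(state_dict) - set(compatible))
    # strict=False tolerates the keys that were filtered out above.
    model.load_state_dict(compatible, strict=False)
    return skipped
```

Printing the returned list shows exactly which weights were left untouched, which is worth checking before fine-tuning on top of a partially loaded model.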

    opened by 523997931 1
  • Have you compared the cartoon and Disney filter effects in SnapChat?

    The Disney filter in Snapchat is very very very stable for video input, but I have no idea how they make it so stable. Since your method works for videos too, I wonder if you have any clue?

    opened by lastapple 0
  • RuntimeError: AUBIO ERROR: source_wavread: Failed opening ./samples/dsm.mp4 (could not find RIFF header)

    Hi, when I run the notebook at https://colab.research.google.com/github/mchong6/GANsNRoses/blob/main/inference_colab.ipynb everything works fine.

    But when I run it on my own machine, I get the error "RuntimeError: AUBIO ERROR: source_wavread: Failed opening ./samples/dsm.mp4 (could not find RIFF header)".

    I also tried the suggestions in https://github.com/aubio/aubio/issues/111 but, after running those commands, I still get the same error. Could you share the details of your virtual environment setup, or explain how to fix this? Thanks.

    opened by DWCTOD 3