Code for the paper: Sketch Your Own GAN

Overview

Sketch Your Own GAN

Project | Paper | YouTube

Our method takes in one or a few hand-drawn sketches and customizes an off-the-shelf GAN to match the input sketch. While our new model changes an object's shape and pose, other visual cues such as color, texture, and background are faithfully preserved after the modification.


Sheng-Yu Wang1, David Bau2, Jun-Yan Zhu1.
CMU1, MIT CSAIL2
In ICCV, 2021.

Training code, evaluation code, and datasets will be released soon.

Results

Our method can customize a pre-trained GAN to match input sketches.

Interpolation using our customized models. Latent space interpolation is smooth with our customized models.

Image 1 → Interpolation → Image 2

Image editing using our customized models. Given a real image (a), we project it to the original model's latent space z using Huh et al. (b). (c) We then feed the projected z to our standing cat model trained on sketches. (d) Finally, we edit the image with the add fur operation using GANSpace.

Failure case. Our method is not capable of generating images to match Attneave's cat sketch or the horse sketch by Picasso. We note that Attneave's cat depicts a complex pose, and Picasso's sketches are drawn in a distinctive style, both of which make our method struggle.

Getting Started

Clone our repo

git clone [email protected]:PeterWang512/GANSketching.git
cd GANSketching

Install packages

  • Install PyTorch (version >= 1.6.0) (pytorch.org)
    pip install -r requirements.txt
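
For example, a fresh environment could look like the following; the environment name and Python version are illustrative choices, and the PyTorch wheel should be picked from pytorch.org to match your CUDA toolkit:

conda create -n gansketching python=3.8 -y
conda activate gansketching
pip install torch torchvision  # choose the build matching your CUDA version
pip install -r requirements.txt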

Download model weights

  • Run bash weights/download_weights.sh

Generate samples from a customized model

This command runs the customized model specified by ckpt and saves the generated samples to save_dir.

# generates samples from the "standing cat" model.
python generate.py --ckpt weights/photosketch_standing_cat_noaug.pth --save_dir output/samples_standing_cat

# generates samples from the cat face model in Figure 1 of the paper.
python generate.py --ckpt weights/by_author_cat_aug.pth --save_dir output/samples_teaser_cat
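
For reference, sampling reduces to loading the generator weights and mapping random latent vectors to images. The sketch below is a guess at a programmatic equivalent: it assumes the rosinality-style Generator this repo imports from training.networks.stylegan2, constructor values matching the options train.py prints (size 256, z_dim 512, n_mlp 8), and that the .pth file holds a bare generator state dict; all of these are assumptions rather than documented behavior.

import torch
from training.networks.stylegan2 import Generator  # import path used elsewhere in this repo

device = 'cuda' if torch.cuda.is_available() else 'cpu'
netG = Generator(256, 512, 8).to(device)  # (size, style_dim, n_mlp); values assumed from the train.py options
state = torch.load('weights/photosketch_standing_cat_noaug.pth', map_location=device)
netG.load_state_dict(state)  # assumes the checkpoint is a bare state dict
netG.eval()

with torch.no_grad():
    z = torch.randn(16, 512, device=device)  # a batch of random latent codes
    images, _ = netG([z])                    # rosinality-style forward: returns (images, latents)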

Latent space edits by GANSpace

Our model preserves the latent space editability of the original model. Our models can apply the same edits using the latent directions reported in Härkönen et al. (GANSpace).

# add fur to the standing cats
python ganspace.py --obj cat --comp_id 27 --scalar 50 --layers 2,4 --ckpt weights/photosketch_standing_cat_noaug.pth --save_dir output/ganspace_fur_standing_cat

# close the eyes of the standing cats
python ganspace.py --obj cat --comp_id 45 --scalar 60 --layers 5,7 --ckpt weights/photosketch_standing_cat_noaug.pth --save_dir output/ganspace_eye_standing_cat
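
Conceptually, such an edit shifts the per-layer style codes along a principal direction at the chosen layers, scaled by --scalar. The toy sketch below illustrates that arithmetic on dummy tensors; the [num_layers, batch, 512] layout and the inclusive reading of --layers 2,4 are assumptions for illustration, not ganspace.py's actual internals.

import torch

def apply_component(styles, direction, scalar, layers):
    # styles: [num_layers, batch, 512] per-layer style codes
    # direction: [512] unit-norm principal component (as in GANSpace)
    edited = styles.clone()
    for i in layers:
        edited[i] = edited[i] + scalar * direction
    return edited

styles = torch.randn(14, 4, 512)          # dummy style codes for a 256px generator
direction = torch.randn(512)
direction = direction / direction.norm()  # normalize to a unit direction
edited = apply_component(styles, direction, scalar=50.0, layers=range(2, 5))  # '--layers 2,4' read as layers 2..4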

Acknowledgments

This repository borrows partially from SPADE, stylegan2-pytorch, PhotoSketch, GANSpace, and data-efficient-gans.

Reference

If you find this useful for your research, please cite the following work.

@inproceedings{wang2021sketch,
  title={Sketch Your Own GAN},
  author={Wang, Sheng-Yu and Bau, David and Zhu, Jun-Yan},
  booktitle={Proceedings of the IEEE International Conference on Computer Vision},
  year={2021}
}

Feel free to contact us with any comments or feedback.

Comments
  • A TypeError occurred while running training scripts.

    When I was running bash scripts/train_teaser_cat.sh on my Linux server, I got a TypeError. My conda virtual environment already has Python 3.7.12, PyTorch 1.7.1, cudatoolkit 10.1.243, torchvision 0.8.2, and ninja 1.10.2 installed. I already tried a higher PyTorch version (1.8.1), but it hit a different problem, subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status, so I went back to the PyTorch 1.7.1 environment. This TypeError has been troubling me for several days. Can someone help me solve it? Thanks.

    Logs in detail:

    (test3) yuanWang@server-TiTan:~/GANtest1$ bash scripts/train_teaser_cat.sh
    Traceback (most recent call last):
      File "train.py", line 6, in <module>
        from eval import Evaluator
      File "/data4/yuanWang/GANtest1/eval/__init__.py", line 1, in <module>
        from .evaluation import Evaluator
      File "/data4/yuanWang/GANtest1/eval/evaluation.py", line 7, in <module>
        from run_metrics import get_vgg_features, make_eval_images
      File "/data4/yuanWang/GANtest1/run_metrics.py", line 13, in <module>
        from training.networks.stylegan2 import Generator
      File "/data4/yuanWang/GANtest1/training/networks/__init__.py", line 2, in <module>
        from .misc import *
      File "/data4/yuanWang/GANtest1/training/networks/misc.py", line 1, in <module>
        from . import stylegan2
      File "/data4/yuanWang/GANtest1/training/networks/stylegan2.py", line 8, in <module>
        from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
      File "/data4/yuanWang/GANtest1/training/networks/op/__init__.py", line 1, in <module>
        from .fused_act import FusedLeakyReLU, fused_leaky_relu
      File "/data4/yuanWang/GANtest1/training/networks/op/fused_act.py", line 15, in <module>
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
      File "/data4/yuanWang/anaconda3/envs/test3/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1091, in load
        keep_intermediates=keep_intermediates)
      File "/data4/yuanWang/anaconda3/envs/test3/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1269, in _jit_compile
        is_standalone=is_standalone,
    TypeError: bump_version_if_changed() got an unexpected keyword argument 'is_python_module'
    
    opened by 0sure 4
  • How to further modify the generate model through GANspace

    How can I use GANSpace to adjust the images generated by net_G? For example, the original generator can generate cat pictures indoors, but I want the background to be outdoor grassland. Is such an operation possible? If so, how can I do it through GANSpace? Thank you very much for your help!

    opened by 0sure 2
  • Segmentation fault (Core dumped) when loading the compiled module fused

    Hello, thank you for your valuable advice. I have now solved the compilation problem with the fused module under PyTorch 1.7.1, and I have the required .so file:

    (test4torch) yuanWang@server-TiTan:~/.cache/torch_extensions/fused$ ls
    build.ninja  fused_bias_act_kernel.cuda.o  fused_bias_act.o  fused.so
    

    I encountered a new problem when I tried to use _import_module_from_library to load the compiled module fused. I changed the code at the beginning of fused_act.py to:

    # imports added for completeness
    import os
    from torch.utils.cpp_extension import load, _import_module_from_library

    try:
        user_home_path = os.path.expanduser('~')
        # load the prebuilt extension directly from the torch_extensions cache
        fused = _import_module_from_library('fused', user_home_path + '/.cache/torch_extensions/fused', True)
        print(f'Load fused from {user_home_path}/.cache/torch_extensions/fused')
        print("Load success!")
    except Exception:
        # fall back to JIT-compiling the extension from source
        module_path = os.path.dirname(__file__)
        fused = load(
            name='fused',
            sources=[
                os.path.join(module_path, 'fused_bias_act.cpp'),
                os.path.join(module_path, 'fused_bias_act_kernel.cu'),
            ],
            verbose=True,
        )
        print('Load function used. Build fused from cpp & cu files')
    

    Executing the code step by step does not report an error, but when I type quit(), the command line reports Segmentation fault (Core dumped). My gdb backtrace shows:

    (gdb) run fused_act.py
    Starting program: /data4/yuanWang/anaconda3/envs/test4torch/bin/python fused_act.py
    [Thread debugging using libthread_db enabled]
    Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
    [New Thread 0x7fff7d446700 (LWP 4153)]
    Load fused from /data4/yuanWang/.cache/torch_extensions/fused
    Load success!
    
    Thread 1 "python" received signal SIGSEGV, Segmentation fault.
    0x00007fffdeafa9b8 in ?? ()
       from /data4/yuanWang/anaconda3/envs/test4torch/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.1
    (gdb) backtrace
    #0  0x00007fffdeafa9b8 in ?? ()
       from /data4/yuanWang/anaconda3/envs/test4torch/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.1
    #1  0x00007fffdeafb1a3 in ?? ()
       from /data4/yuanWang/anaconda3/envs/test4torch/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.1
    #2  0x00007fffdeafb8a5 in ?? ()
       from /data4/yuanWang/anaconda3/envs/test4torch/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.1
    #3  0x00007ffff7806031 in __run_exit_handlers (status=0, 
        listp=0x7ffff7bae718 <__exit_funcs>, 
        run_list_atexit=run_list_atexit@entry=true, run_dtors=run_dtors@entry=true)
        at exit.c:108
    #4  0x00007ffff780612a in __GI_exit (status=<optimized out>) at exit.c:139
    #5  0x00007ffff77e4c8e in __libc_start_main (main=0x555555645ab0 <main>, argc=2, 
        argv=0x7fffffffe158, init=<optimized out>, fini=<optimized out>, 
        rtld_fini=<optimized out>, stack_end=0x7fffffffe148)
        at ../csu/libc-start.c:344
    #6  0x000055555572b73d in _start () at ../sysdeps/x86_64/elf/start.S:103
    

    Do you know how to solve this problem? By the way, did you choose Windows as the operating system to run the code? I strongly suspect that this problem is related to the Ubuntu operating system. Thank you for your time. Best wishes.

    opened by 0sure 2
  • Evaluation FID in trained model doesn't match the one from the pre-trained model

    Hello, thank you for making the code available.

    I executed the training script:

    bash scripts/train_photosketch_horse_riders.sh

    and then I compared the metrics of the trained model with those of the file:

    weights/photosketch_horse_riders_aug.pth

    using the evaluation script:

    run_metrics.py

    with the command:

    python run_metrics.py --models_list weights/eval_list --output metric_results.csv

    I added the trained weights to the weights folder and added an entry for them to the eval_list file so that I could check their FID too.

    I got very different FIDs for photosketch_horse_riders_aug.pth (FID ~20.13) and the trained one (FID ~37.07). I assumed that the training script would produce a model similar to the one stored at photosketch_horse_riders_aug.pth.

    Is there some other procedure to get a trained model with the same FID as photosketch_horse_riders_aug.pth, which is very close to (but not quite the same as) the FID in the paper (~19.94)?

    opened by arturandre 2
  • How to load the parameters of the latest model and continue training.

    I found that train.py contains code that automatically saves the latest model after a certain number of iterations. When I stop training manually, how can I continue the unfinished training from the saved model? Thanks for your help!
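
    A hedged pointer: the options dump in the "Training process" issue below lists a resume_iter flag (default None), so passing the iteration number of a saved checkpoint may resume training from it; the flag's exact semantics are assumed here, not verified:

    python train.py --name cat_train_5 --resume_iter 2500 --dataroot_sketch ./data/sketch/by_author/cat --dataroot_image ./data/image/cat --g_pretrained ./pretrained/stylegan2-cat/netG.pth --d_pretrained ./pretrained/stylegan2-cat/netD.pth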

    opened by 0sure 1
  • Google Colab for Examples

    Not exactly an issue.

    I created a Google Colab notebook that sets up the repo and runs the example commands from the README. I did this because of issues I was having with my local setup. Hopefully this can be added to give an easier quick start for those interested in the project.

    Colab Link

    opened by bionboy 1
  • Problems about fused_act.py. TypeError: bump_version_if_changed() got an unexpected keyword argument 'is_python_module' and subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

    When I was running fused_act.py, I got a TypeError. I am using a Linux server without root or sudo permission, and my conda virtual environment is: Python 3.7, PyTorch 1.7.1, torchvision 0.8.2, cudatoolkit V10.1.243, gcc 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04), and ninja 1.10.2. I used conda install -c conda-forge ninja to install ninja, but I am not sure about ninja's working status. I already changed line 1631 in myName/anaconda3/envs/test1/lib/python3.7/site-packages/torch/utils/cpp_extension.py from ['ninja', '-v'] to ['ninja', '-version'] to avoid subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status. After that, I got the TypeError. Could anybody help me please?

    Logs in detail (in the PyTorch 1.7.1 virtual environment):

    Traceback (most recent call last):
      File "fused_act.py", line 15, in <module>
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
      File "/data4/yuanWang/anaconda3/envs/test3/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1091, in load
        keep_intermediates=keep_intermediates)
      File "/data4/yuanWang/anaconda3/envs/test3/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1269, in _jit_compile
        is_standalone=is_standalone,
    TypeError: bump_version_if_changed() got an unexpected keyword argument 'is_python_module'

    opened by 0sure 0
  • Training process

    I was trying to train for the first time using the example:

    "python train.py --name cat_train_5 --batch 16 --dataroot_sketch ./data/sketch/by_author/cat --dataroot_image ./data/image/cat --l_image 0.7 - -g_pretrained ./pretrained/stylegan2-cat/netG.pth --d_pretrained ./pretrained/stylegan2-cat/netD.pth --max_iter 150000 --disable_eval --diffaug_policy translation --no_wandb"

    And everything seems to be fine:

    Using pretrained weight for D1...
    Using pretrained weight for D_image...
    ----------------- Options ---------------
    batch: 16 [default: 4]
    beta1: 0.0
    beta2: 0.99
    channel_multiplier: 2
    checkpoints_dir: checkpoint
    d_pretrained: ./pretrained/stylegan2-cat/netD.pth [default: ]
    d_reg_every: 16
    dataroot_image: ./data/image/cat [default: None]
    dataroot_sketch: ./data/sketch/by_author/cat [default: None]
    diffaug_policy: translation [default: ]
    disable_eval: True [default: False]
    display_freq: 2500
    display_winsize: 400
    dsketch_no_pretrain: False
    eval_batch: 50
    eval_dir: None
    eval_freq: 5000
    g_pretrained: ./pretrained/stylegan2-cat/netG.pth [default: ]
    gan_mode: softplus
    isTrain: True [default: None]
    l_image: 0.7 [default: 0]
    l_weight: 0
    latent_avg_samples: 8192
    lr: 0.002
    lr_mlp: 0.01
    max_epoch: 1000000
    max_iter: 150000 [default: 75001]
    mixing: 0.9
    n_mlp: 8
    name: cat_train_5 [default: None]
    no_d_regularize: False
    no_html: False
    no_wandb: True [default: False]
    optim_param_g: style
    photosketch_path: ./pretrained/photosketch.pth
    print_freq: 100
    r1: 10
    reduce_visuals: False
    resume_iter: None
    save_freq: 2500
    size: 256
    sketch_channel: 1
    transform_fake: toSketch,to3ch
    transform_real: to3ch
    use_cpu: False
    z_dim: 512
    ----------------- End -------------------

    -------------- Trainables ---------------
    (G trainable parameters)
    style.1.weight
    style.1.bias
    style.2.weight
    style.2.bias
    style.3.weight
    style.3.bias
    style.4.weight
    style.4.bias
    style.5.weight
    style.5.bias
    style.6.weight
    style.6.bias
    style.7.weight
    style.7.bias
    style.8.weight
    style.8.bias
    ----------------- End -------------------
    create web directory checkpoint\cat_train_5\web...
    Training was successfully finished.

    I expected that the weights would be generated in the checkpoint folder and then used to generate the images, but I only found two files, "ops" and "log_loss" (which is empty).

    I am running this on Windows 10.

    Is the process correct? Have I forgotten something? Could the problem be the OS?

    I'm new to working with GANs, so any help will be appreciated.

    opened by ananas1178 1
  • ImportError: No module named 'fused'

    Hi Wang,

    Thanks for sharing your code.

    I have installed all the prerequisite packages. When I try to compile the module 'fused', the errors below occur. I removed all cache files in the torch_extensions folder using rm -rf before every build attempt but still got the same error. By the way, I already changed '['ninja', '-v']' to '['ninja', '--version']' in cpp_extension.py under anaconda3/envs/CTtest1/lib/python3.7/site-packages/torch/utils/ to avoid another error. Could you please tell me how to fix it?

    Detected CUDA files, patching ldflags
    Emitting ninja build file /home/slhua/.cache/torch_extensions/fused/build.ninja...
    Building extension module fused...
    Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
    1.7.2
    Loading extension module fused...
    Traceback (most recent call last):
      File "fused_act.py", line 25, in <module>
        is_python_module = True)
      File "/home/slhua/anaconda3/envs/CTtest1/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 997, in load
        keep_intermediates=keep_intermediates)
      File "/home/slhua/anaconda3/envs/CTtest1/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1213, in _jit_compile
        return _import_module_from_library(name, build_directory, is_python_module)
      File "/home/slhua/anaconda3/envs/CTtest1/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1560, in _import_module_from_library
        file, path, description = imp.find_module(module_name, [path])
      File "/home/slhua/anaconda3/envs/CTtest1/lib/python3.7/imp.py", line 296, in find_module
        raise ImportError(_ERR_MSG.format(name), name=name)
    ImportError: No module named 'fused'
    

    My environment: cudatoolkit 10.2.89, cudnn 7.6.5, python 3.7.6, pytorch 1.7.1, torchvision 0.8.2, ninja 1.7.2

    opened by UangBell 1
  • How to transform my own sketch to latent z?

    Hi! Glad to see your work! But I have a question, as follows.

    Consider practical usage:
    Step 1: I make a cat sketch image by hand.
    Step 2: Transform the sketch image to latent_z.
    Step 3: Feed latent_z to the netG network to get a cat image.

    I am wondering how to realize step 2. Do I need netG, the photo2Sketch network, and the pix2latent method? Or only netG with the pix2latent method? If I just use netG to get z, it will still generate a cat sketch, not a cat image, won't it?

    Thanks!

    opened by hahaCui 4
  • How to generate using your own sketches?

    Is it possible to generate cats by providing my own sketches? Looking at the code, it seems that all sketches are hardcoded, but maybe I'm missing something.
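
    A hedged pointer: the train.py command quoted in the "Training process" issue above takes a --dataroot_sketch folder, which suggests training on your own sketches by pointing that flag at a directory of them; the paths below are hypothetical:

    python train.py --name my_cat --dataroot_sketch ./data/sketch/my_own_cat --dataroot_image ./data/image/cat --g_pretrained ./pretrained/stylegan2-cat/netG.pth --d_pretrained ./pretrained/stylegan2-cat/netD.pth --no_wandb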

    opened by adelorenz 6