StyleMapGAN - Official PyTorch Implementation

StyleMapGAN: Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing
Hyunsu Kim, Yunjey Choi, Junho Kim, Sungjoo Yoo, Youngjung Uh
In CVPR 2021.

Paper: https://arxiv.org/abs/2104.14754
Video: https://youtu.be/qCapNyRA_Ng

Abstract: Generative adversarial networks (GANs) synthesize realistic images from random latent vectors. Although manipulating the latent vectors controls the synthesized outputs, editing real images with GANs suffers from i) time-consuming optimization for projecting real images to the latent vectors, or ii) inaccurate embedding through an encoder. We propose StyleMapGAN: the intermediate latent space has spatial dimensions, and a spatially variant modulation replaces AdaIN. It makes the embedding through an encoder more accurate than existing optimization-based methods while maintaining the properties of GANs. Experimental results demonstrate that our method significantly outperforms state-of-the-art models in various image manipulation tasks such as local editing and image interpolation. Last but not least, conventional editing methods on GANs are still valid on our StyleMapGAN. Source code is available at https://github.com/naver-ai/StyleMapGAN.
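The abstract's key change, replacing AdaIN with a spatially variant modulation, can be illustrated with a toy pure-Python sketch. It is our own illustration, not the repository's layer: it assumes a SPADE-like form in which features are instance-normalized and then scaled and shifted, and it takes the per-pixel scale and bias as inputs instead of computing them from a stylemap. AdaIN uses one (scale, bias) pair per channel; the spatial version has one pair per channel and pixel, so different regions can be styled differently.

```python
"""Toy contrast between AdaIN and spatially variant modulation.

Assumption: SPADE-like form out = gamma * normalized + beta. This is an
illustration only, not StyleMapGAN's actual modulation layer.
"""
import math
from typing import List

Grid = List[List[float]]   # one channel: [row][col]
Feature = List[Grid]       # [channel][row][col]


def instance_normalize(x: Feature, eps: float = 1e-5) -> Feature:
    """Normalize each channel to zero mean and unit variance over space."""
    out = []
    for grid in x:
        values = [v for row in grid for v in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        std = math.sqrt(var + eps)
        out.append([[(v - mean) / std for v in row] for row in grid])
    return out


def adain(x: Feature, gamma: List[float], beta: List[float]) -> Feature:
    """AdaIN: one scale and one bias per channel, shared by every pixel."""
    norm = instance_normalize(x)
    return [[[g * v + b for v in row] for row in grid]
            for grid, g, b in zip(norm, gamma, beta)]


def spatial_modulation(x: Feature, gamma: Feature, beta: Feature) -> Feature:
    """Spatially variant modulation: a scale and a bias for every pixel."""
    norm = instance_normalize(x)
    return [[[g * v + b for v, g, b in zip(row, g_row, b_row)]
             for row, g_row, b_row in zip(grid, g_grid, b_grid)]
            for grid, g_grid, b_grid in zip(norm, gamma, beta)]
```

With a spatially constant scale and bias the two functions agree, so AdaIN behaves like a special case of the spatial version; a scale map that is zero on one side suppresses the normalized features only in that region.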

Demo

YouTube video. The teaser video is available at https://youtu.be/qCapNyRA_Ng.

Interactive demo app. Run the demo on your local machine:

All test images are from CelebA-HQ, AFHQ, and LSUN.

python demo.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --dataset celeba_hq

Installation

Tested environment: Ubuntu, gcc 7.4.0, CUDA, CUDA driver, cuDNN 7, conda, Python 3.6.12, PyTorch 1.4.0.

Clone this repository:

git clone https://github.com/naver-ai/StyleMapGAN.git
cd StyleMapGAN/

Install the dependencies:

conda create -y -n stylemapgan python=3.6.12
conda activate stylemapgan
./install.sh

Datasets and pre-trained networks

We provide a script to download datasets used in StyleMapGAN and the corresponding pre-trained networks. The datasets and network checkpoints will be downloaded and stored in the data and expr/checkpoints directories, respectively.

CelebA-HQ. To download the CelebA-HQ dataset and parse it, run the following commands:

# Download raw images and create LMDB datasets using them
# Additional files are also downloaded for local editing
bash download.sh create-lmdb-dataset celeba_hq

# Download the pretrained network (256x256)
bash download.sh download-pretrained-network-256 celeba_hq

# Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator)
bash download.sh download-pretrained-network-1024 ffhq_16x16

AFHQ. For AFHQ, replace 'celeba_hq' with 'afhq' in the commands above.
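Several comments on this page report _pickle.UnpicklingError: invalid load key, '<' when loading checkpoints. That message means the first byte of the file is '<', which usually indicates that an HTML page (for example a Google Drive warning or error page) was saved instead of the checkpoint. The following hypothetical helper, which is not part of the repository, inspects only the first bytes of each .pt file in expr/checkpoints without unpickling anything.

```python
"""Hypothetical pre-flight check for downloaded checkpoints (not part of the repo).

Only the first bytes of each file are read, so it is cheap even for large
checkpoints, and nothing is unpickled.
"""
from pathlib import Path

ZIP_MAGIC = b"PK\x03\x04"  # zip-based torch.save format (default since PyTorch 1.6)
PICKLE_PROTO = b"\x80"     # legacy torch.save output starts with a pickle PROTO opcode


def classify_checkpoint(path) -> str:
    """Return 'zip', 'legacy-pickle', 'html', 'empty', or 'unknown'."""
    with open(path, "rb") as f:
        head = f.read(512)
    if not head:
        return "empty"
    if head.startswith(ZIP_MAGIC):
        return "zip"
    if head.startswith(PICKLE_PROTO):
        return "legacy-pickle"
    if head.lstrip().startswith(b"<"):
        # Looks like HTML; torch.load reports "invalid load key, '<'"
        # when '<' is the very first byte.
        return "html"
    return "unknown"


def check_checkpoint_dir(ckpt_dir="expr/checkpoints"):
    """Map each *.pt file in ckpt_dir to the format detected above."""
    return {p.name: classify_checkpoint(p) for p in sorted(Path(ckpt_dir).glob("*.pt"))}


if __name__ == "__main__":
    for name, kind in check_checkpoint_dir().items():
        if kind not in ("zip", "legacy-pickle"):
            print(f"{name}: looks like {kind}, consider re-downloading it")
```

A file flagged as html can simply be deleted and downloaded again; the check deliberately avoids torch.load so it cannot fail in the same way.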

Train network

Training is implemented with DistributedDataParallel.
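DistributedDataParallel runs one process per GPU, each on its own slice of the data, and averages gradients across processes with an all-reduce before the optimizer step, so every replica applies the same update. The single-process toy below is an illustration under the assumption of equal-sized shards, not the repository's training loop: it checks that averaging per-shard gradients of a mean loss reproduces the full-batch gradient, and shows that the number of images per step is the per-GPU batch times the number of GPUs.

```python
"""Toy, single-process simulation of DDP-style gradient averaging.

Assumptions: equal-sized shards per rank and DDP's default behaviour of
averaging gradients with an all-reduce; no processes or GPUs are involved.
"""
from typing import List


def mse_grad(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w on one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(per_rank: List[float]) -> float:
    """What every rank holds after an averaging all-reduce of a scalar gradient."""
    return sum(per_rank) / len(per_rank)


def effective_batch_size(batch_per_gpu: int, world_size: int) -> int:
    """Each rank processes its own mini-batch, so one step sees the product."""
    return batch_per_gpu * world_size


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
    w, per_gpu, world = 0.5, 2, 3
    shard_grads = [mse_grad(w, xs[r * per_gpu:(r + 1) * per_gpu],
                            ys[r * per_gpu:(r + 1) * per_gpu]) for r in range(world)]
    print(all_reduce_mean(shard_grads), mse_grad(w, xs, ys))  # prints -45.5 -45.5
```

For example, 2 images per GPU on 3 GPUs means 6 images contribute to each update.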

# CelebA-HQ
python train.py --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val

# AFHQ
python train.py --dataset afhq --train_lmdb data/afhq/LMDB_train --val_lmdb data/afhq/LMDB_val

# CelebA-HQ / 1024x1024 image / 16x16 stylemap / Light version of Generator
python train.py --size 1024 --latent_spatial_size 16 --small_generator --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val 
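The checkpoint names pair a 256x256 image with an 8x8 stylemap and a 1024x1024 image with a 16x16 stylemap, which matches the --size and --latent_spatial_size flags in the command above. Assuming, as in StyleGAN2, that each generator block doubles the spatial resolution, the number of upsampling blocks follows from the ratio of the two sizes; this small sketch makes that arithmetic explicit.

```python
"""Relate output resolution to stylemap resolution (illustrative only).

Assumption: each generator block doubles the spatial size, as in StyleGAN2.
"""


def upsampling_steps(image_size: int, latent_spatial_size: int) -> int:
    """Number of 2x upsamplings needed to go from the stylemap to the image."""
    if image_size <= 0 or latent_spatial_size <= 0:
        raise ValueError("sizes must be positive")
    if image_size % latent_spatial_size:
        raise ValueError("image_size must be a multiple of latent_spatial_size")
    ratio = image_size // latent_spatial_size
    if ratio & (ratio - 1):  # non-zero exactly when ratio is not a power of two
        raise ValueError("image_size / latent_spatial_size must be a power of two")
    return ratio.bit_length() - 1


if __name__ == "__main__":
    for size, latent in [(256, 8), (1024, 16)]:
        print(size, latent, upsampling_steps(size, latent))
```

For example, 256/8 = 32 = 2^5 gives 5 doublings, and 1024/16 = 64 = 2^6 gives 6.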

Generate images

Reconstruction. Results are saved to expr/reconstruction.

# CelebA-HQ
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/celeba_hq/LMDB_test

# AFHQ
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/afhq/LMDB_test

W interpolation. Results are saved to expr/w_interpolation.

# CelebA-HQ
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/celeba_hq/LMDB_test

# AFHQ
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/afhq/LMDB_test
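W interpolation presumably encodes two test images into stylemaps and decodes stylemaps blended between them. A minimal sketch of the blending step, assuming plain linear interpolation on stylemaps stored as nested lists; the encoder and generator calls are omitted:

```python
"""Linear interpolation between two stylemaps (illustrative sketch).

Assumption: both stylemaps come from the encoder and share one shape, given
here as nested lists; decoding each intermediate map is left out.
"""


def lerp_stylemaps(w_a, w_b, t: float):
    """Element-wise (1 - t) * w_a + t * w_b over arbitrarily nested lists."""
    if isinstance(w_a, list):
        return [lerp_stylemaps(a, b, t) for a, b in zip(w_a, w_b)]
    return (1.0 - t) * w_a + t * w_b


def interpolation_path(w_a, w_b, steps: int):
    """Evenly spaced stylemaps from w_a (t = 0) to w_b (t = 1), endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    return [lerp_stylemaps(w_a, w_b, i / (steps - 1)) for i in range(steps)]
```

Interpolating on the stylemap keeps the spatial layout, so each intermediate map still has the same grid size as the encoder output.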

Local editing. Results are saved to expr/local_editing. We pair images by the similarity of their target semantic masks. For details, see preprocessor/README.md.
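To make the pairing step concrete, here is a hypothetical sketch that picks, for a source mask, the candidate reference whose mask overlaps it most. Using intersection over union (IoU) as the similarity is our assumption; preprocessor/README.md describes the criterion the repository actually uses.

```python
"""Pick a reference image whose mask best matches the source (hypothetical).

Assumption: "mask similarity" is measured as intersection over union (IoU)
of binary masks; the repository's criterion may differ.
"""
from typing import List

Mask = List[List[int]]


def mask_iou(a: Mask, b: Mask) -> float:
    """IoU of two equal-sized binary masks; two empty masks count as identical."""
    pairs = [(x, y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    inter = sum(1 for x, y in pairs if x and y)
    union = sum(1 for x, y in pairs if x or y)
    return inter / union if union else 1.0


def best_reference(source: Mask, candidates: List[Mask]) -> int:
    """Index of the candidate mask with the highest IoU against the source."""
    return max(range(len(candidates)), key=lambda i: mask_iou(source, candidates[i]))
```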

# Using ground-truth (GT) segmentation masks for the CelebA-HQ dataset.
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose

# Using half-and-half masks for the AFHQ dataset.
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type local_editing --test_lmdb data/afhq/LMDB_test
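As we understand the paper, local editing blends the source and reference stylemaps with the target mask resized to the stylemap resolution, so the edited stylemap takes reference values inside the mask and source values outside it. The toy sketch below assumes average pooling for the resizing and a left/right split for the half-and-half mask; both are illustrative choices, not details taken from generate.py.

```python
"""Mask-guided stylemap blending for local editing (toy sketch).

Assumptions: stylemaps are [channel][row][col] nested lists, the image mask is
square and average-pooled down to the stylemap grid, and the half-and-half
mask marks the left half as the edited region.
"""
from typing import List

Grid = List[List[float]]


def downsample_mask(mask: Grid, size: int) -> Grid:
    """Average-pool a square H x H mask to size x size (H divisible by size)."""
    k = len(mask) // size
    return [[sum(mask[r][c] for r in range(i * k, (i + 1) * k)
                 for c in range(j * k, (j + 1) * k)) / (k * k)
             for j in range(size)]
            for i in range(size)]


def blend_stylemaps(w_src: List[Grid], w_ref: List[Grid], m: Grid) -> List[Grid]:
    """m * w_ref + (1 - m) * w_src at every position, same m for all channels."""
    return [[[mv * r + (1 - mv) * s for s, r, mv in zip(s_row, r_row, m_row)]
             for s_row, r_row, m_row in zip(s_grid, r_grid, m)]
            for s_grid, r_grid in zip(w_src, w_ref)]


def half_and_half_mask(size: int) -> Grid:
    """Hypothetical half-and-half mask: 1.0 on the left half, 0.0 on the right."""
    return [[1.0 if c < size // 2 else 0.0 for c in range(size)] for _ in range(size)]
```

Because the mask is pooled rather than thresholded, pixels on a mask boundary yield fractional weights and therefore a soft transition between source and reference.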

Unaligned transplantation. Results are saved to expr/transplantation. It shows local transplantation examples on AFHQ. We recommend using the demo app instead.

python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type transplantation --test_lmdb data/afhq/LMDB_test

Random Generation. Results are saved to expr/random_generation. It shows random generation examples.

python generate.py --mixing_type random_generation --ckpt expr/checkpoints/celeba_hq_256_8x8.pt

Style Mixing. Results are saved to expr/stylemixing. It shows style mixing examples.

python generate.py --mixing_type stylemixing --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --test_lmdb data/celeba_hq/LMDB_test

Semantic Manipulation. Results are saved to expr/semantic_manipulation. It shows local semantic manipulation examples.

python semantic_manipulation.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --LMDB data/celeba_hq/LMDB --svm_train_iter 10000
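Semantic manipulation trains an SVM on latent codes of positive and negative samples, which (according to the Related Projects section) are obtained by ranking. A common recipe, used by InterFaceGAN, takes the unit normal of the separating hyperplane as the editing direction and moves a latent code along it. The sketch below follows that recipe on toy 2D vectors; hinge-loss SGD stands in for the actual solver, and whether semantic_manipulation.py works exactly this way, or what --svm_train_iter controls internally, is an assumption.

```python
"""InterFaceGAN-style editing direction from a linear SVM (toy sketch).

Assumptions: latent codes are flattened to plain vectors, labels are +1 for
positive and -1 for negative samples, and hinge-loss SGD stands in for the
SVM solver the repository actually uses.
"""
import math
import random
from typing import List, Tuple

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def train_linear_svm(xs: List[Vec], ys: List[int], iters: int = 10000,
                     lr: float = 0.01, lam: float = 1e-3,
                     seed: int = 0) -> Tuple[Vec, float]:
    """SGD on lam/2 * ||w||^2 + hinge loss; returns (weights, bias)."""
    rng = random.Random(seed)
    w = [0.0] * len(xs[0])
    b = 0.0
    for _ in range(iters):
        i = rng.randrange(len(xs))
        x, y = xs[i], ys[i]
        if y * (dot(w, x) + b) < 1:  # inside the margin: hinge term is active
            w = [wj - lr * (lam * wj - y * xj) for wj, xj in zip(w, x)]
            b += lr * y
        else:  # only the L2 regularizer contributes
            w = [wj - lr * lam * wj for wj in w]
    return w, b


def unit(v: Vec) -> Vec:
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def edit_latent(latent: Vec, direction: Vec, alpha: float) -> Vec:
    """Move a latent code along the unit boundary normal by step alpha."""
    return [x + alpha * d for x, d in zip(latent, direction)]


if __name__ == "__main__":
    pos = [[2.0, 1.0], [3.0, -1.0], [2.5, 0.5]]
    neg = [[-2.0, 1.0], [-3.0, -1.0], [-2.5, 0.0]]
    w, b = train_linear_svm(pos + neg, [1] * 3 + [-1] * 3)
    print(edit_latent(neg[0], unit(w), alpha=3.0))
```

A positive alpha pushes a code toward the positive side of the boundary and a negative alpha pushes it the other way; only the direction of the hyperplane normal is used, not the bias.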

Metrics

  • Reconstruction: LPIPS, MSE
  • W interpolation: FID_lerp
  • Generation: FID
  • Local editing: MSE_src, MSE_ref, Detectability (refer to CNNDetection)

For details, see metrics/README.md.
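For local editing, MSE_src and MSE_ref can be read as masked errors: the edited image should match the source outside the mask and the reference inside it. The sketch below computes both on grayscale grids under that reading; metrics/README.md remains the authority on the exact definitions.

```python
"""Masked errors for local editing, as we read MSE_src and MSE_ref (sketch).

Assumptions: images are equal-sized 2D grayscale grids, and truthy mask cells
mark the edited (reference) region.
"""
from typing import Dict, List

Grid = List[List[float]]


def masked_mse(a: Grid, b: Grid, mask: Grid, inside: bool) -> float:
    """MSE over pixels inside (inside=True) or outside (inside=False) the mask."""
    total, count = 0.0, 0
    for row_a, row_b, row_m in zip(a, b, mask):
        for x, y, m in zip(row_a, row_b, row_m):
            if bool(m) == inside:
                total += (x - y) ** 2
                count += 1
    return total / count if count else 0.0


def local_editing_errors(src: Grid, ref: Grid, edited: Grid,
                         mask: Grid) -> Dict[str, float]:
    """MSE_src: edit vs. source outside the mask; MSE_ref: edit vs. reference inside it."""
    return {
        "MSE_src": masked_mse(src, edited, mask, inside=False),
        "MSE_ref": masked_mse(ref, edited, mask, inside=True),
    }
```

Lower is better for both: a perfect local edit leaves the unmasked region untouched and copies the reference exactly inside the mask.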

License

The source code, pre-trained models, and dataset are available under the Creative Commons BY-NC 4.0 license by NAVER Corporation. You can use, copy, transform, and build upon the material for non-commercial purposes as long as you give appropriate credit by citing our paper and indicate if changes were made.

For business inquiries, please contact [email protected].
For technical and other inquiries, please contact [email protected].

Citation

If you find this work useful for your research, please cite our paper:

@inproceedings{kim2021stylemapgan,
  title={Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing},
  author={Kim, Hyunsu and Choi, Yunjey and Kim, Junho and Yoo, Sungjoo and Uh, Youngjung},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}

Related Projects

The model code is based on an unofficial PyTorch implementation of StyleGAN2, which in turn follows the official StyleGAN2 code. LPIPS, FID, and CNNDetection code is used for evaluation. For semantic manipulation, we used a pretrained StyleGAN network to obtain positive and negative samples by ranking. The demo code is based on Neural-Collage.

Comments
  • UnpicklingError: invalid load key, '<'.

    Hello authors,

    I get an error when using the pretrained checkpoints: it occurs when torch.load(args.ckpt) runs in generate.py (or any other code that calls this function). I tried re-downloading the model as proposed in this issue. The only change I made, to accommodate my CUDA 11 setup, was to install PyTorch with conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch. Thanks for your attention, it is much appreciated.

    opened by josejhlee 9
  • Questions about metrics for reconstruction

    Hello, thanks for sharing your code. I'm very interested in your work, and I have a question about running the metrics for reconstruction. The outputs of metrics.reconstruction are 0.015 (MSE) and 0.214 (LPIPS), which seem much better than the results reported in Table 3 of the paper (MSE=0.024, LPIPS=0.242). I would appreciate it if you could give me an explanation.

    opened by Annonymous-repos 5
  • Download.sh updated link points to possibly corrupted models

    I downloaded the pre-trained networks using the updated download.sh script, and PyTorch is unable to load them, giving the following error: _pickle.UnpicklingError: invalid load key, '<'

    I suspect that the pre-trained network on the google drive is corrupted. Could you check whether it is the case?

    opened by BigBenPost 2
  • Fail to open demo

    Thanks for the excellent work!

    I failed to run the demo on my local machine. The app does not show any error message, and the webpage does not seem to load the demo images. Can you help me fix it?

    I can run all the other commands, such as the unaligned transplantation command, but not the demo.

    opened by berylsheep-up 2
  • GPU memory shortage problem when loading weights from checkpoints

    Hello, I would first like to thank you for sharing your work.

    I am having a problem loading weights from checkpoints (i.e., when resuming halted training).

    I am training StyleMapGAN on a custom dataset (~200K training images at 1024x1024 resolution), and I am currently using 3 TitanRTX GPUs. I am using latent_spatial_size=16 considering the input image resolution and GPU memory. With this configuration, a batch of 2 is allocated per GPU, using ~21 GiB of memory.

    Training from scratch works fine. I have not tried using pretrained weights trained on FFHQ or CelebA because my data is quite different from human faces. Moreover, since I have succeeded in generating images with generate.py, I think the weights were saved properly.

    However, a memory allocation error occurs every time I load my custom weights to continue training. I assumed extra memory might be required when loading weights, so I tried a smaller batch size (batch 2 per GPU -> batch 1 per GPU), but the same memory shortage occurs.

    To summarize, I cannot load weights to continue training, whereas training from scratch and loading weights to generate images both work well. Therefore, I would like to ask the following questions.

    1. Have any of the authors experienced similar problems?
    2. Would there be any possible solutions to my problem?

    I would be grateful if you could take a look at my question. Thank you!

    opened by junikkoma 2
  • How many iterations do you use to train the model

    Thank you for publishing the code. Your work is very impressive. I wonder how many iterations you used to train the model. I noticed that the default is 1,400,000. Is this value used for all your training datasets? 1,400,000 training iterations seems a bit long; I am curious whether a smaller number of iterations would suffice.

    opened by zhang-lingyun 2
  • can i get a new pretrained network file?

    Here is how to get the pretrained networks:

    # Download raw images and create LMDB datasets using them
    # Additional files are also downloaded for local editing
    bash download.sh create-lmdb-dataset celeba_hq
    
    # Download the pretrained network (256x256) 
    bash download.sh download-pretrained-network-256 celeba_hq # 20M-image-trained models
    bash download.sh download-pretrained-network-256 celeba_hq_5M # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies.
    
    # Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator)
    bash download.sh download-pretrained-network-1024 ffhq_16x16
    

    but with these networks, it doesn't work:

    File "demo.py", line 192, in <module>
      ckpt = torch.load(args.ckpt)
    File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 608, in load
      return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 777, in _legacy_load
      magic_number = pickle_module.load(f, **pickle_load_args)
    _pickle.UnpicklingError: invalid load key, '<'.

    I also found that after running train.py, I got 000000.pt (untrained), and with this 000000.pt, demo.py works (but the output image is noise).

    So is there any way to get a new pretrained network?

    I tried PyTorch 1.4.0 and 1.10, on a remote server with Docker and on Colab.

    opened by ThisisLandu 1
  • How about the generator and the encoder being trained together?

    Hi, thanks for your great work. As you say in the paper, the generator and the encoder are trained jointly, not separately. How about training them together as a unified framework without fixing any weights?

    opened by diaodeyi 0
  • metrics in table 3

    How did you obtain the MSE and LPIPS of Image2StyleGAN, Structured Noise, and StyleGAN2 in Table 3? Structured Noise did not provide code for image reconstruction, and Image2StyleGAN did not provide code (no official code found). What about StyleGAN2? Could you tell me, please?

    opened by JNash123 0
  • Encoding quality of real images

    Hello, thank you for the interesting work! Using python demo.py --ckpt expr/checkpoints/celeba_hq_8x8_20M_revised.pt --dataset celeba_hq, I tried to evaluate the embedding quality of real photos.

    It works impressively well on the images provided by default.

    However, it works drastically worse on photos from the internet (the photos are aligned by face landmarks and cropped to 1024x1024 with https://github.com/ZPdesu/Barbershop/blob/main/align_face.py).

    What could be the reason, and how do I fix this?

    opened by gordinmitya 3
  • How to do StyleMixing on custom dataset

    Hello. Thanks for this code. I have trained the model on my own dataset, which is different from the face datasets. Image reconstruction and random generation work well. However, if I run style mixing, there is a hard-coded pkl file required that is related to the celeba_hq dataset (data/celeba_hq/local_editing/celeba_hq_test_GT_sorted_pair.pkl).

    Just to run style mixing, I downloaded the dataset and ran the code; however, the results are not good.

    Could you please share your thoughts on how to apply stylemixing on a custom dataset?

    opened by ammar-deep 0
  • Different sizes between downloaded lmdb data and generated data with prepare_data.py

    Hi,

    I used the provided prepare_data.py to convert the downloaded raw images to LMDB data at 256x256. However, the resulting LMDB data has a different size from the downloaded LMDB data (LMDB_train/test/val). I also tried other sizes such as 128, 512, and 1024, but none of them matches the downloaded LMDB data.

    Could you explain the reason?

    Thanks.

    opened by rshaojimmy 0
  • How to calculate or obtain GT_labels and LMDB_test_mask for other images in CelebA_HQ

    Hi, thanks for your excellent work!

    May I know how to calculate or obtain GT_labels and LMDB_test_mask for other images in CelebA_HQ, since this repo only provides the processed masks, downloadable via download.sh?

    Thanks!

    opened by rshaojimmy 1