Labels4Free: Unsupervised Segmentation using StyleGAN

Overview

ICCV 2021

Figure: Segmentation masks predicted by the Labels4Free framework on real and synthetic images.

We propose an unsupervised segmentation framework for StyleGAN-generated objects. We build on two main observations. First, the features generated by StyleGAN hold valuable information that can be used to train segmentation networks. Second, the foreground and background can often be treated as largely independent and swapped across images to produce plausible composited images. For our solution, we propose to augment the StyleGAN2 generator architecture with a segmentation branch and to split the generator into a foreground and a background network. This enables us to generate soft segmentation masks for the foreground object in an unsupervised fashion. On multiple object classes, we report results comparable to state-of-the-art supervised segmentation networks, while against the best unsupervised segmentation approach we demonstrate a clear improvement in both qualitative and quantitative metrics.
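As a rough sketch of the compositing idea (function and variable names are ours, not the repo's API), the alpha branch predicts a soft mask that blends the foreground and background generator outputs:

import torch

def composite(fg_rgb, bg_rgb, alpha):
    # fg_rgb, bg_rgb: (N, 3, H, W) images from the foreground/background networks
    # alpha: (N, 1, H, W) soft segmentation mask in [0, 1] from the alpha branch
    return alpha * fg_rgb + (1.0 - alpha) * bg_rgb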

Labels4Free: Unsupervised Segmentation Using StyleGAN (ICCV 2021)
Rameen Abdal, Peihao Zhu, Niloy Mitra, Peter Wonka
KAUST, Adobe Research

[Paper] [Project Page] [Video]

Installation

Clone this repo.

git clone https://github.com/RameenAbdal/Labels4Free.git
cd Labels4Free/

This repo is based on the PyTorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). Refer to that repo for setting up the environment, preparing the LMDB datasets, and downloading pretrained weights for the models.
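For reference, LMDB preparation in that repo follows a pattern like the one below (check its README for the exact flags in your version):

python prepare_data.py --out [LMDB_DATASET_PATH] --n_worker [N_WORKER] --size 256 [IMAGE_FOLDER_PATH]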

Download the pretrained weights of the Alpha Networks here.

Training the models

The models were trained on 4 RTX 2080 (24 GB) GPUs. To train the models with the settings used in the paper, run the following commands for each dataset.

Checkpoints and samples are saved in ./checkpoint and ./sample folders.

FFHQ dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT] --loss_multiplier 1.2 --iter 1200 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Horse dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 500 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Cat dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 900 --trunc 0.5 --lr 0.0002 --reproduce_model

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 10 --iter 50 --trunc 0.3 --lr 0.002 --sat_weight 1.0 --model_save_freq 25 --reproduce_model --use_disc

To train your own models with different settings (e.g., on a single GPU, or with a different number of samples, iterations, etc.), use the following commands.

FFHQ dataset

python train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT] --loss_multiplier 1.2 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 3 --bg_coverage_value 0.4

LSUN-Horse dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 6 --bg_coverage_value 0.6

LSUN-Cat dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 0.5 --lr 0.0002 --bg_coverage_wt 4 --bg_coverage_value 0.35

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 20 --iter 750 --trunc 0.3 --lr 0.0008 --sat_weight 0.1 --bg_coverage_wt 40 --bg_coverage_value 0.75 --model_save_freq 50
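The --bg_coverage_wt and --bg_coverage_value flags weight a prior on how much of the image the background mask should cover, which keeps the alpha network from collapsing to an all-foreground mask. The actual loss is defined in train.py; purely as an illustration of what a coverage term of this shape might look like (our guess, not the repo's code):

import torch

def bg_coverage_loss(alpha, target=0.4, weight=3.0):
    # alpha: (N, 1, H, W) soft foreground mask in [0, 1].
    # Hypothetical penalty keeping the background fraction (1 - alpha)
    # near a target coverage, so the mask cannot claim the whole image.
    bg_fraction = (1.0 - alpha).mean(dim=(1, 2, 3))
    return weight * torch.relu(target - bg_fraction).mean()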

Sample from the pretrained model

Samples are saved in ./test_sample folder.

python test_sample.py --size [SIZE] --batch 2 --n_sample 100 --ckpt_bg_extractor [ALPHANETWORK_MODEL] --ckpt_generator [GENERATOR_MODEL] --th 0.9
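Here --th presumably sets the threshold used to binarize the soft alpha mask when saving hard masks. A minimal sketch of the idea (variable names are ours, not the script's):

import torch

alpha = torch.rand(1, 1, 256, 256)   # stand-in for a soft mask from the alpha network
hard_mask = (alpha > 0.9).float()    # --th 0.9: keep only confident foreground pixels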

Results on Custom dataset

Folder: the custom dataset together with predicted and ground-truth masks.

python test_customdata.py --path_gt [GT_Folder] --path_pred [PRED_FOLDER]
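A minimal mean-IoU sketch of this kind of evaluation, assuming both folders hold binary masks with matching filenames (the script itself may compute additional metrics):

import os
import numpy as np
from PIL import Image

def iou(pred, gt):
    # pred, gt: boolean arrays of the same shape
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / union if union > 0 else 1.0

def mean_iou(path_gt, path_pred, threshold=127):
    scores = []
    for name in sorted(os.listdir(path_gt)):
        gt = np.array(Image.open(os.path.join(path_gt, name)).convert("L")) > threshold
        pred = np.array(Image.open(os.path.join(path_pred, name)).convert("L")) > threshold
        scores.append(iou(pred, gt))
    return float(np.mean(scores))

print(mean_iou("[GT_Folder]", "[PRED_FOLDER]"))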

Citation

@InProceedings{Abdal_2021_ICCV,
    author    = {Abdal, Rameen and Zhu, Peihao and Mitra, Niloy J. and Wonka, Peter},
    title     = {Labels4Free: Unsupervised Segmentation Using StyleGAN},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {13970-13979}
}

Acknowledgments

This implementation builds upon the PyTorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). This work was supported by Adobe Research and the KAUST Office of Sponsored Research (OSR).

Comments
  • covert_weights

    When I convert weights from .pkl to .pt, it doesn't seem to work; the process hangs, and when I end it with Ctrl+C it turns out the process was sleeping:

    (pytorch) PS D:\FYP\Labels4Free-main> python convert_weight.py --repo ~/pt stylegan2-cat-config-f.pkl
    Traceback (most recent call last):
      File "convert_weight.py", line 11, in <module>
        from model_new import Generator, Discriminator
      File "D:\FYP\Labels4Free-main\model_new.py", line 12, in <module>
        from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
      File "D:\FYP\Labels4Free-main\op\__init__.py", line 2, in <module>
      File "D:\FYP\Labels4Free-main\op\upfirdn2d.py", line 10, in <module>
        upfirdn2d_op = load(
      File "C:\Users\wu\.conda\envs\pytorch\lib\site-packages\torch\utils\cpp_extension.py", line 1124, in load
        return _jit_compile(
      File "C:\Users\wu\.conda\envs\pytorch\lib\site-packages\torch\utils\cpp_extension.py", line 1351, in _jit_compile
        baton.wait()
      File "C:\Users\wu\.conda\envs\pytorch\lib\site-packages\torch\utils\file_baton.py", line 42, in wait
        time.sleep(self.wait_seconds)
    KeyboardInterrupt

    opened by whyydsforever 1
  • Flower Segmentation

    Behavior of the original code:

    1. Too dependent on the car dataset → the data-loading code only supports the LMDB format in which LSUN-Car is stored.
    2. In the model architecture, when the generated output size is not 512 (the flower model was trained at 256), a shape mismatch error occurs even with the channel multiplier set to 1.

    Fixes:

    1. Depending on the dataset name and whether LMDB is used: when LMDB is not used, load TestDataset, a torch dataset for StyleGAN training, and use that dataset class instead.
    2. For now, hard-coded the channel dimension of the failing layer to end at 64 instead of 128.

    Remaining work:

    For item 2, the exact cause of the error should be investigated further so that the code can handle these arguments more flexibly.

    opened by YoojLee 0
  • When I train with FFHQ, the alpha mask is not generated

    I would like to train with the FFHQ dataset, but there are some problems: when I train with FFHQ, the alpha mask is not generated.

    The mask is all white.

    How can I fix it?

    opened by yejees 0
  • Car backgrounds not matching results from the paper

    Hi, first of all, thanks for the great work! I've been trying to replicate the results for the car dataset, but I found the background outputs to be different from the paper's results: the images look more like random textures than actual backgrounds (cf. the attached images 000000_background and 000000_original). Any idea why?

    Cheers

    Ben

    opened by BenjBarral 1
  • If I can use this project for grayscale images

    What modifications should be done to generate grayscale images? Loading my checkpoint fails with:

    File "C:\Users\wu\.conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1482, in load_state_dict
      raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for Generator:
      size mismatch for to_rgb1.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgb1.conv.weight: copying a param with shape torch.Size([1, 1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
      size mismatch for to_rgbs.0.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.0.conv.weight: copying a param with shape torch.Size([1, 1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
      size mismatch for to_rgbs.1.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.1.conv.weight: copying a param with shape torch.Size([1, 1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
      size mismatch for to_rgbs.2.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.2.conv.weight: copying a param with shape torch.Size([1, 1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
      size mismatch for to_rgbs.3.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.3.conv.weight: copying a param with shape torch.Size([1, 1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 512, 1, 1]).
      size mismatch for to_rgbs.4.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.4.conv.weight: copying a param with shape torch.Size([1, 1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 256, 1, 1]).
      size mismatch for to_rgbs.5.bias: copying a param with shape torch.Size([1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 1, 1]).
      size mismatch for to_rgbs.5.conv.weight: copying a param with shape torch.Size([1, 1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 3, 128, 1, 1]).

    opened by whyydsforever 1