Label-Efficient Semantic Segmentation with Diffusion Models

Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models.

This code is based on datasetGAN and guided-diffusion.

Note: use --recurse-submodules when cloning.

 

Overview

The paper investigates the representations learned by the state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple segmentation approach that exploits these representations and outperforms the alternatives in the few-shot operating point in the context of semantic segmentation.
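
Below is a minimal conceptual sketch of this approach, not the repository's actual pipeline: intermediate UNet activations of a pretrained DDPM are collected at a few diffusion steps, upsampled to the input resolution, and each pixel's feature vector is classified by a small MLP. The helper extract_unet_features, the step/block choices, and the feature dimension shown are illustrative assumptions.

    # Conceptual sketch only; NOT the repository's actual code.
    # `extract_unet_features` is a hypothetical helper standing in for
    # collecting intermediate UNet decoder activations.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def pixel_features(image, diffusion, model,
                       steps=(50, 150, 250), blocks=(5, 6, 7, 8, 12)):
        """Noise the image to several diffusion steps, run the UNet, and
        concatenate upsampled activations from the chosen decoder blocks."""
        feats = []
        for t in steps:
            t_batch = torch.full((image.shape[0],), t, dtype=torch.long)
            noisy = diffusion.q_sample(image, t_batch)   # forward diffusion to step t
            for act in extract_unet_features(model, noisy, t_batch, blocks):
                feats.append(F.interpolate(act, image.shape[-2:], mode="bilinear"))
        return torch.cat(feats, dim=1)                   # [B, C_total, H, W]

    # Each pixel's feature vector is classified independently by a small MLP
    # (the paper trains an ensemble of such classifiers); sizes are examples.
    mlp = nn.Sequential(
        nn.Linear(8448, 128), nn.ReLU(),
        nn.Linear(128, 32), nn.ReLU(),
        nn.Linear(32, 21),   # e.g. 21 classes for Horse-21
    )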

[Figure: DDPM-based Segmentation]

 

Dependencies

  • Python >= 3.7
  • Packages: see requirements.txt (install with pip install -r requirements.txt)

 

Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-28, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The number corresponds to the number of semantic classes.

datasets.tar.gz (~47 MB)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. FFHQ-256 was trained by us using the same model parameters as the LSUN models (a loading sketch follows the checkpoint list below).

  • LSUN-Bedroom: lsun_bedroom.pt
  • FFHQ-256: ffhq.pt
  • LSUN-Cat: lsun_cat.pt
  • LSUN-Horse: lsun_horse.pt
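
These checkpoints can presumably be loaded with the bundled guided-diffusion code roughly as below; the flag values shown are assumptions, and the exact ones are set via MODEL_FLAGS in the provided scripts.

    # Hedged sketch: flag values are assumptions; see MODEL_FLAGS in scripts/.
    import torch
    from guided_diffusion.script_util import (
        model_and_diffusion_defaults,
        create_model_and_diffusion,
    )

    opts = model_and_diffusion_defaults()
    opts.update(image_size=256, num_channels=256, num_res_blocks=2)  # assumed
    model, diffusion = create_model_and_diffusion(**opts)
    state = torch.load("checkpoints/ddpm/lsun_bedroom.pt", map_location="cpu")
    model.load_state_dict(state)
    model.eval()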

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh
  3. Check paths in experiments/<dataset_name>/ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

How to improve the performance

  1. Set input_activations=true in experiments/<dataset_name>/ddpm.json.
       In this case, the feature dimension is 18432.
  2. Tune which diffusion steps and UNet blocks to use for the particular task (see the sketch after this list).
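
For instance, such tuning could be scripted roughly as below. The "steps" and "blocks" keys are assumptions about the config schema; check the JSON files under experiments/ for the actual field names.

    # Rough sketch; the "steps" and "blocks" key names are assumptions.
    import json

    path = "experiments/ffhq_34/ddpm.json"
    with open(path) as f:
        cfg = json.load(f)

    cfg["input_activations"] = True    # feature dimension becomes 18432
    cfg["steps"] = [50, 150, 250]      # assumed key: diffusion steps to probe
    cfg["blocks"] = [5, 6, 7, 8, 12]   # assumed key: UNet decoder blocks to use

    with open(path, "w") as f:
        json.dump(cfg, f, indent=4)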

 

DatasetDDPM

Synthetic datasets

To download the DDPM-produced synthetic datasets (50000 samples, ~7 GB):
bash synthetic-datasets/ddpm/download_synthetic_dataset.sh

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh
  2. Check paths in experiments/<dataset_name>/datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh

  3. Check paths in experiments/<dataset_name>/datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh
        Please adjust the hyperparameters in this script to fit the available resources.
        On 8xA100 80GB, it takes about 12 hours to generate 10000 samples (roughly 100 samples per GPU-hour).

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

SwAV

Pretrained SwAVs

We pretrain SwAV models on the LSUN and FFHQ-256 datasets using the official implementation (a loading sketch follows the training setup table below):

  • LSUN-Bedroom: lsun_bedroom.pth
  • FFHQ-256: ffhq.pth
  • LSUN-Cat: lsun_cat.pth
  • LSUN-Horse: lsun_horse.pth

Training setup:

  Dataset    Epochs  Batch size  Multi-crop      Num. prototypes
  LSUN       200     1792        2x256 + 6x108   1000
  FFHQ-256   400     2048        2x224 + 6x96    200
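
Since the official SwAV implementation trains a ResNet-50, these checkpoints can presumably be loaded into a standard torchvision backbone along the following lines; the checkpoint layout (nesting, "module." prefixes) is an assumption.

    # Hedged sketch: checkpoint layout is assumed; adjust key handling as needed.
    import torch
    from torchvision.models import resnet50

    backbone = resnet50()
    state = torch.load("checkpoints/swav/lsun_bedroom.pth", map_location="cpu")
    state = state.get("state_dict", state)                  # unwrap if nested
    state = {k.replace("module.", ""): v for k, v in state.items()}
    missing, unexpected = backbone.load_state_dict(state, strict=False)
    backbone.eval()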

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/swav/download_checkpoint.sh
  3. Check paths in experiments/<dataset_name>/swav.json
  4. Run: bash scripts/swav/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

DatasetGAN

In contrast to the official implementation, we use more recent StyleGAN2(-ADA) models.

Synthetic datasets

To download the GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh

Run

Since we almost fully adopt the official DatasetGAN implementation, we do not provide our own reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh
  2. Change paths in experiments/<dataset_name>/datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

Results

  • Performance in terms of mean IoU:

  Method         Bedroom-28   FFHQ-34      Cat-15       Horse-21     CelebA-19    ADE-Bedroom-30
  ALAE           20.0 ± 1.0   48.1 ± 1.3   --           --           49.7 ± 0.7   15.0 ± 0.5
  VDVAE          --           57.3 ± 1.1   --           --           54.1 ± 1.0   --
  GAN Inversion  13.9 ± 0.6   51.7 ± 0.8   21.4 ± 1.7   17.7 ± 0.4   51.5 ± 2.3   11.1 ± 0.2
  GAN Encoder    22.4 ± 1.6   53.9 ± 1.3   32.0 ± 1.8   26.7 ± 0.7   53.9 ± 0.8   15.7 ± 0.3
  SwAV           41.0 ± 2.3   54.7 ± 1.4   44.1 ± 2.1   51.7 ± 0.5   53.2 ± 1.0   30.3 ± 1.5
  DatasetGAN     31.3 ± 2.7   57.0 ± 1.0   36.5 ± 2.3   45.4 ± 1.4   --           --
  DatasetDDPM    46.9 ± 2.8   56.0 ± 0.9   45.4 ± 2.8   60.4 ± 1.2   --           --
  DDPM           46.1 ± 1.9   57.0 ± 1.4   52.3 ± 3.0   63.1 ± 0.9   57.0 ± 1.0   32.3 ± 1.5

 

  • Examples of segmentation masks predicted by the DDPM-based method:

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Comments
  • What parameter should I use to implement the 64*64 diffusion model in DDPM-segmentation

    Hey, I'm confused. I just trained a model with image size 64 and 4000 steps using improved-diffusion, but when I try to use it for segmentation, I get errors saying that many tensor sizes don't match. Can you help me figure out what I should adjust in experiments/ddpm.json?

    opened by johnwalking 5
  • How can I use the code to implement the 64*64 diffusion model in DDPM-segmentation

    In my experiments, I want to use 64*64 images in DDPM-segmentation due to limited resources, but the code raises "RuntimeError: Error(s) in loading state_dict for UNetModel".

    opened by Wanyidon 4
  • Some questions about a program that has been killed

    Thank you very much. I ran into a new problem:

    Total dimension 8448 scripts/ddpm/train_interpreter.sh: line 6: 169447 Killed python train_interpreter.py --exp experiments/${DATASET}/ddpm.json $MODEL_FLAGS

    The process dies at line 57 of train_interpreter.py, "X = X.permute(1,0,2,3).reshape(d, -1).permute(1, 0)". I think my computer with 512 GB of memory should be enough for this. My PyTorch version is 1.8.0; I don't know if that matters.

    opened by HoJ-Onle 4
  • Question about colorize_mask

    Hello, thanks for your great work and for sharing this code! When running it, I encountered an issue with the function "colorize_mask". Specifically, when I use the following code in "pixel_classifier.py",

    mask = colorize_mask(pred[0], palette)
    

    an error is reported.

    And it works fine when I modify it to

    mask = colorize_mask(pred, palette)
    

    So is it a bug?

    opened by JingyeChen 3
  • FID for guided diffusion model on FFHQ 256

    Dear ddpm-segmentation team,

    Thank you for sharing this great work. I really enjoy it.

    Could you tell me the FID of the guided diffusion model you pretrained on FFHQ 256?

    Thank you for your help.

    Best Wishes,

    Zongze

    opened by betterze 2
  • Some questions about mpi4py

    Hi, thanks for your work. I had some problems while following it: I can't install "mpi4py". How can I run the code without this requirement?

    opened by HoJ-Onle 2
  • Question about the transform pipeline

    Hi, thanks for sharing your code.

    I'm confused by the transform applied to input images: why are they transformed by lambda x: x * 2 - 1? I also tried discarding this transform, and the model performs worse than before. Could you please enlighten me on this question?

    Thank you!

    opened by yhuang1997 1
  • What's the scheme for dataset splitting over 5 independent runs?

    Hi, this work is remarkable, and thank you for sharing your code! I want to reproduce and cite your work but have a question about the dataset-splitting scheme. As stated in the paper, all results come from five different dataset splits, yet the downloaded dataset has already been split into train & test sets. Can I combine train and test and then split them randomly? Is that the right way to use these datasets? I would appreciate it if you could explain. Thank you!

    opened by yhuang1997 1
  • How to reproduce Figure 4 in the paper

    Following the paper, I ran k-means clustering on the features from decoder blocks {6, 8, 10, 12}, but I failed: the reproduced result is complete mosaic. Could the code for this part be released?

    opened by StonERMax 1
  • How can I train a DDPM model on my own dataset?

    Hi! Thank you very much for sharing the code; I'm a newbie with diffusion models. I know that FFHQ-256 was trained by you. I would like to try other datasets, but I don't know how. Could you please share code for training a model on my own dataset (like CIFAR10)?

    Thank you very much!

    opened by Liareee 0
  • Pre-trained pixel classifier?

    Hi,

    Is there any chance to get the pre-trained pixel classifier you trained on your data? Due to the high memory usage, I am not able to train the model and just want to evaluate it, so I need the pre-trained MLPs for pixel classification. It would be amazing if you could share them. Thanks.

    opened by HOMGH 2
  • Memory issue?

    Hi, thanks for sharing your code. I got the error "scripts/ddpm/train_interpreter.sh: line 6: 3842 Killed python train_interpreter.py --exp experiments/${DATASET}/ddpm.json $MODEL_FLAGS".

    I have ~65 GB of RAM available on my Ubuntu machine. Considering your note that "it requires ~210Gb for 50 training images of 256x256", does that mean it is not feasible to train the model on my system? How about evaluation? Thanks in advance.

    opened by HOMGH 1
Owner: Yandex Research