A simple consistency training framework for semi-supervised image semantic segmentation

Overview

PseudoSeg: Designing Pseudo Labels for Semantic Segmentation

PseudoSeg is a simple consistency training framework for semi-supervised image semantic segmentation. It introduces a simple and novel redesign of pseudo-labeling that generates well-calibrated, structured pseudo labels for training with unlabeled or weakly-labeled data. It was implemented by Yuliang Zou (research intern) in summer 2020.

This is not an official Google product.

Instructions

Installation

  • Use a virtual environment
virtualenv -p python3 --system-site-packages env
source env/bin/activate
  • Install packages
pip install -r requirements.txt

Dataset

Create a dataset folder under the ROOT directory, then download the pre-created tfrecords for voc12 and coco and extract them into the dataset folder. You may also want to check the filenames for each split under the data_splits folder.
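
A minimal sketch of the expected layout (the archive names below are hypothetical; substitute the actual filenames of the downloaded tfrecord packages):

mkdir -p dataset
# Extract the downloaded tfrecord packages into dataset/ (archive names are placeholders).
tar -xzf voc12_tfrecords.tar.gz -C dataset
tar -xzf coco_tfrecords.tar.gz -C dataset
# The per-split filename lists live under data_splits/ for cross-checking.
ls data_splits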

Training

NOTE:

  • We train all our models using 16 V100 GPUs.
  • The ImageNet pre-trained models can be downloaded here.
  • For VOC12, ${SPLIT} can be 2_clean, 4_clean, 8_clean, or 16_clean_3 (representing the 1/2, 1/4, 1/8, and 1/16 splits), and ${NUM_ITERATIONS} should be set to 30000.
  • For COCO, ${SPLIT} can be 32_all, 64_all, 128_all, 256_all, or 512_all (representing the 1/32, 1/64, 1/128, 1/256, and 1/512 splits), and ${NUM_ITERATIONS} should be set to 200000. An example variable setup is shown after this list.
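
For example, training on the 1/8 VOC12 split might use the following variable setup (all paths are placeholders; adapt them to your machine):

export SPLIT=8_clean
export NUM_ITERATIONS=30000
export INIT_FOLDER=/path/to/init_models        # must contain xception_65/model.ckpt
export TRAIN_LOGDIR=/path/to/train_logs
export DATASET=/path/to/dataset/pascal_voc_seg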

Supervised baseline

python train_sup.py \
  --logtostderr \
  --train_split="${SPLIT}" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size="513,513" \
  --num_clones=16 \
  --train_batch_size=64 \
  --training_number_of_steps="${NUM_ITERATIONS}" \
  --fine_tune_batch_norm=true \
  --tf_initial_checkpoint="${INIT_FOLDER}/xception_65/model.ckpt" \
  --train_logdir="${TRAIN_LOGDIR}" \
  --dataset_dir="${DATASET}"

PseudoSeg (w/ unlabeled data)

python train_wss.py \
  --logtostderr \
  --train_split="${SPLIT}" \
  --train_split_cls="train_aug" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size="513,513" \
  --num_clones=16 \
  --train_batch_size=64 \
  --training_number_of_steps="${NUM_ITERATIONS}" \
  --fine_tune_batch_norm=true \
  --tf_initial_checkpoint="${INIT_FOLDER}/xception_65/model.ckpt" \
  --train_logdir="${TRAIN_LOGDIR}" \
  --dataset_dir="${DATASET}"

PseudoSeg (w/ image-level labeled data)

python train_wss.py \
  --logtostderr \
  --train_split="${SPLIT}" \
  --train_split_cls="train_aug" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size="513,513" \
  --num_clones=16 \
  --train_batch_size=64 \
  --training_number_of_steps="${NUM_ITERATIONS}" \
  --fine_tune_batch_norm=true \
  --tf_initial_checkpoint="${INIT_FOLDER}/xception_65/model.ckpt" \
  --train_logdir="${TRAIN_LOGDIR}" \
  --dataset_dir="${DATASET}" \
  --weakly=true

Evaluation

NOTE: ${EVAL_CROP_SIZE} should be 513,513 for VOC12 and 641,641 for COCO.
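
For example, when evaluating a VOC12 model:

export EVAL_CROP_SIZE="513,513"   # use "641,641" for COCO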

python eval.py \
  --logtostderr \
  --eval_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --eval_crop_size="${EVAL_CROP_SIZE}" \
  --checkpoint_dir="${TRAIN_LOGDIR}" \
  --eval_logdir="${EVAL_LOGDIR}" \
  --dataset_dir="${DATASET}" \
  --max_number_of_evaluations=1

Visualization

NOTE: ${VIS_CROP_SIZE} should be 513,513 for VOC12 and 641,641 for COCO.
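
For example, when visualizing VOC12 predictions (the checkpoint and output paths below are placeholders):

export VIS_CROP_SIZE="513,513"      # use "641,641" for COCO
export CKPT="${TRAIN_LOGDIR}"       # directory holding the trained checkpoint
export VIS_LOGDIR=/path/to/vis_logs
export PASCAL_DATASET="${DATASET}"  # the same tfrecord folder used for training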

python vis.py \
  --logtostderr \
  --vis_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --vis_crop_size="${VIS_CROP_SIZE}" \
  --checkpoint_dir="${CKPT}" \
  --vis_logdir="${VIS_LOGDIR}" \
  --dataset_dir="${PASCAL_DATASET}" \
  --also_save_raw_predictions=true

Citation

If you use this work for your research, please cite our paper.

@article{zou2020pseudoseg,
  title={PseudoSeg: Designing Pseudo Labels for Semantic Segmentation},
  author={Zou, Yuliang and Zhang, Zizhao and Zhang, Han and Li, Chun-Liang and Bian, Xiao and Huang, Jia-Bin and Pfister, Tomas},
  journal={International Conference on Learning Representations (ICLR)},
  year={2021}
}
Comments
  • Question about Tab.5 which compares different pseudo labeling strategies

    Thanks for your impressive work!

    I notice that in Tab. 5, different pseudo labeling strategies are compared. Are the numbers listed here mIOU on the validation set, or mIOU of the pseudo masks on unlabeled images from the training set?

    It seems to be the latter; if so, are these unlabeled images from the original VOC training set (1,464 images) or from the augmented SBD images?

    opened by LiheYoung 6
  • Concern on the details of the comparison results in Table 2

    Really nice paper!

    We carefully read your work and find the experimental setting on PASCAL VOC in Table 2 (shown below) really interesting: in the last column of Table 2, all the methods use only 92 images as the labeled set and, according to the code, choose the train-aug set (10,582 images) as the unlabeled set:

    https://github.com/googleinterns/wss/blob/8069dbe8b68b409a891224508f35c6ae5ecec4c9/core/data_generator.py#L85-L104

    and,

    https://github.com/googleinterns/wss/blob/280cc1a6ceb5326044ee7521706d3d293c4aeb40/train_wss.py#L796

    Our understanding is that FLAGS.train_split_cls represents the set of unlabeled images used for training, and its value is train_aug by default. So the number of unlabeled images is more than 100x the number of labeled images. Given that the total number of training iterations is set as training_number_of_steps=30000, we will iterate over the sampled 92 labeled images for nearly 30000*64/92 ≈ 20869 epochs. Is my understanding correct?

    If my understanding is correct, we are curious whether training for so many epochs on the 92 labeled images is a good choice. Besides, since the train-aug set (10,582 images) contains the 92 labeled images, we guess all the methods also apply the pseudo-label-based/consistency-based methods to the labeled images (instead of only to the unlabeled images).

    Many thanks, and we look forward to your explanation if our understanding is wrong!

    [screenshot of Table 2]

    opened by PkuRainBow 4
  • About the device

    Hi, thanks for your work! I am trying to reproduce your work on 2 GPUs, so I run the following command:

    python train_wss.py \
      --logtostderr \
      --train_split="8_clean" \
      --train_split_cls="train_aug" \
      --model_variant="xception_65" \
      --atrous_rates=6 \
      --atrous_rates=12 \
      --atrous_rates=18 \
      --output_stride=16 \
      --decoder_output_stride=4 \
      --train_crop_size="513,513" \
      --num_clones=2 \
      --train_batch_size=8 \
      --training_number_of_steps="30000" \
      --fine_tune_batch_norm=true \
      --tf_initial_checkpoint="/wss-master/init/deeplabv3_xception_2018_01_04/model.ckpt" \
      --train_logdir="/wss-master/train_log" \
      --dataset_dir="/wss-master/data/pascal_voc_seg"

    But it seems that the GPUs are not utilized; the code just runs on CPUs. I'm not quite familiar with TensorFlow, so could you please give me some hints to solve this problem? Thanks!

    opened by revaeb 3
  • More data splits information

    Hi, thanks for sharing! I have two questions about the data splits:

    1. Could you please supply 2_clean.txt for PASCAL VOC?
    2. Could you tell me which 1/16 split of VOC you used in Table 2 of the main paper? You have supplied three different 1/16 splits in this repo, which is confusing.

    Thanks again!

    opened by charlesCXK 3
  • Influence of the color jittering parameters

    Great work! We find that you apply only color jittering as the strong augmentation, so we are very interested in the influence of the color jittering parameter choices.

    For example, the default setting in the released code is:

    https://github.com/googleinterns/wss/blob/8069dbe8b68b409a891224508f35c6ae5ecec4c9/core/preprocess_utils.py#L715-L718

    According to the SimCLR paper, they set these values as follows:

      brightness = 0.8
      contrast = 0.8
      saturation = 0.8
      hue = 0.2
    

    It would be great if you could share more results on the influence of these four hyperparameters!

    opened by PkuRainBow 2
  • Security Policy violation Binary Artifacts

    This issue was automatically created by Allstar.

    Security Policy Violation: the project is out of compliance with the Binary Artifacts policy (binaries present in source code).

    Rule Description: Binary artifacts are an increased security risk in your repository. Binary artifacts cannot be reviewed, allowing the introduction of possibly obsolete or maliciously subverted executables. For more information, see the Security Scorecards documentation for Binary Artifacts.

    Remediation Steps: To remediate, remove the generated executable artifacts from the repository.

    First 10 Artifacts Found

    • third_party/deeplab/__pycache__/__init__.cpython-36.pyc
    • third_party/deeplab/__pycache__/common.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/__init__.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/conv2d_ws.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/dense_prediction_cell.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/feature_extractor.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/nas_cell.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/nas_genotypes.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/nas_network.cpython-36.pyc
    • third_party/deeplab/core/__pycache__/preprocess_utils.cpython-36.pyc
    • Run a Scorecards scan to see full list.

    Additional Information: This policy is drawn from Security Scorecards, a tool that scores a project's adherence to security best practices. You may wish to run a Scorecards scan directly on this repository for more details.


    Allstar has been installed on all Google-managed GitHub orgs. Policies are gradually being rolled out and enforced by the GOSST and OSPO teams. Learn more at http://go/allstar

    This issue will auto resolve when the policy is in compliance.

    Issue created by Allstar. See https://github.com/ossf/allstar/ for more information. For questions specific to the repository, please contact the owner or maintainer.

    opened by allstar-app[bot] 109
  • About the results shown in Table 1

    Hi, sorry to bother you, and thanks for sharing. Could you please tell me how to set --train_split="${SPLIT}" when I want to reproduce your results shown in Table 1 (the semi-supervised setting with 1.4k labeled images)? Should it be --train_split="8_clean"? Or is that split only for the low-data setting? Thanks for your help!

    opened by revaeb 6
  • Class problem of the additional data training

    Thanks for the great work. In section 4.5 of the published paper, different datasets were combined to train a model in a supervised way. However, the class labels of the different datasets (i.e., COCO, VOC, and Cityscapes) differ, so how is this problem handled during training? Looking forward to your reply.

    opened by Pei233 1
  • Segmentation fault

    Hi, thank you so much for your work! I'd like to try it on a different dataset, and I was wondering if you could guide me through the most important things I have to prepare to run your code. I started with the most basic thing: I created a dataset directory, downloaded the pre-created tfrecords for voc12, and put them in dataset. I wanted to try training on one GPU, so I ran python3 train_sup.py --num_clones 1 --train_logdir logs/ --dataset_dir dataset/, but I am getting a segmentation fault error. What do you think I am doing wrong?

    Thank you so much in advance!

    opened by saramsv 17
Owner
Google Interns