Exploring Cross-Image Pixel Contrast for Semantic Segmentation

Overview

Exploring Cross-Image Pixel Contrast for Semantic Segmentation,
Wenguan Wang, Tianfei Zhou, Fisher Yu, Jifeng Dai, Ender Konukoglu and Luc Van Gool
arXiv technical report (arXiv:2101.11939)

Abstract

Current semantic segmentation methods focus only on mining “local” context, i.e., dependencies between pixels within individual images, via context-aggregation modules (e.g., dilated convolution, neural attention) or structure-aware optimization criteria (e.g., IoU-like loss). However, they ignore the “global” context of the training data, i.e., the rich semantic relations between pixels across different images. Inspired by recent advances in unsupervised contrastive representation learning, we propose a pixel-wise contrastive framework for semantic segmentation in the fully supervised setting. The core idea is to enforce pixel embeddings belonging to the same semantic class to be more similar than embeddings from different classes. This raises a pixel-wise metric learning paradigm for semantic segmentation, by explicitly exploring the structures of labeled pixels, which were long ignored in the field. Our method can be effortlessly incorporated into existing segmentation frameworks without extra overhead during testing.

We experimentally show that, with well-known segmentation models (i.e., DeepLabV3, HRNet, OCR) and backbones (i.e., ResNet, HRNet), our method brings consistent performance improvements across diverse datasets (i.e., Cityscapes, PASCAL-Context, COCO-Stuff).
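
The core idea above amounts to a supervised, pixel-level variant of the InfoNCE objective. The following is a minimal sketch of that objective for a batch of sampled pixel embeddings; it is our own illustration with hypothetical names, not the repository's implementation (which additionally uses hard anchor sampling and a memory bank):

    import torch
    import torch.nn.functional as F

    def pixel_contrastive_loss(embeddings, labels, temperature=0.1):
        # embeddings: (N, D) pixel embeddings sampled from a batch
        # labels:     (N,)   ground-truth class ids of those pixels
        z = F.normalize(embeddings, dim=1)                  # unit-length embeddings
        logits = z @ z.t() / temperature                    # (N, N) pairwise similarities
        logits = logits - logits.max(dim=1, keepdim=True)[0].detach()  # numerical stability

        same_class = labels.view(-1, 1).eq(labels.view(1, -1)).float()
        not_self = 1.0 - torch.eye(len(labels), device=z.device)
        positives = same_class * not_self                   # same class, excluding self

        exp_logits = torch.exp(logits) * not_self
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True))

        # mean log-likelihood over each anchor's positives, averaged over anchors
        n_pos = positives.sum(dim=1).clamp(min=1)
        return -((positives * log_prob).sum(dim=1) / n_pos).mean()

Pixels of the same class attract one another, while all remaining pixels act as negatives in the denominator, which is exactly the "more similar within a class than across classes" constraint described above.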

Installation

This implementation is built on openseg.pytorch. Many thanks to the authors for their efforts.

Please follow the Getting Started guide for installation and dataset preparation.

Running

Cityscapes

  1. Train DeepLabV3

    bash scripts/cityscapes/deeplab/run_r_101_d_8_deeplabv3_train_contrast.sh train 'resnet101-deeplabv3-contrast'

Features (in progress)

  • Pixel-wise Contrastive Loss
  • Hard Anchor Sampling
  • Memory Bank (see the sketch after this list)
  • Hard Example Mining
  • Model Zoo
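
Several of the questions in the comments below concern the memory bank. As a mental model, it behaves like a fixed-size ring buffer of pixel embeddings per class; the sketch below is our own simplified illustration (hypothetical names, not the repository's _dequeue_and_enqueue):

    import torch

    class PixelMemoryBank:
        # A fixed-size ring buffer of pixel embeddings per class (illustrative).
        def __init__(self, num_classes, memory_size, feat_dim):
            self.queue = torch.zeros(num_classes, memory_size, feat_dim)
            self.ptr = torch.zeros(num_classes, dtype=torch.long)
            self.memory_size = memory_size

        @torch.no_grad()
        def enqueue(self, cls, feats):
            # Insert K embeddings for class `cls`, overwriting the oldest entries.
            # Assumes K <= memory_size.
            k = feats.shape[0]
            ptr = int(self.ptr[cls])
            if ptr + k <= self.memory_size:
                self.queue[cls, ptr:ptr + k] = feats
            else:                                   # wrap around the ring
                first = self.memory_size - ptr
                self.queue[cls, ptr:] = feats[:first]
                self.queue[cls, :k - first] = feats[first:]
            self.ptr[cls] = (ptr + k) % self.memory_size    # advance by K

Embeddings stored in the bank serve as additional positives and negatives for the contrastive loss, giving each anchor access to pixels from other images without recomputing their features.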

t-SNE Visualization

  • Pixel-wise Cross-Entropy Loss

  • Pixel-wise Contrastive Learning Objective
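
The two plots referenced above project high-dimensional pixel embeddings onto a 2D plane. A minimal recipe for producing such plots (our own sketch using scikit-learn; the subsampling strategy is an assumption, not necessarily what the authors did):

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    def plot_tsne(embeddings, labels, max_points=5000, seed=0):
        # embeddings: (N, D) array of pixel embeddings; labels: (N,) class ids
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(labels), size=min(max_points, len(labels)), replace=False)
        xy = TSNE(n_components=2, init="pca", random_state=seed).fit_transform(embeddings[idx])
        plt.scatter(xy[:, 0], xy[:, 1], c=labels[idx], s=2, cmap="tab20")
        plt.axis("off")
        plt.show()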

Citation

@article{wang2021exploring,
  title   = {Exploring Cross-Image Pixel Contrast for Semantic Segmentation},
  author  = {Wang, Wenguan and Zhou, Tianfei and Yu, Fisher and Dai, Jifeng and Konukoglu, Ender and Van Gool, Luc},
  journal = {arXiv preprint arXiv:2101.11939},
  year    = {2021}
}

Comments
  • Problem in contrastive loss

    Hi, Dr. Zhou,

    Thanks for releasing the code. When reading the code for the contrastive loss in the function _contrastive(), a mask is computed by the following two lines: https://github.com/tfzhou/ContrastiveSeg/blob/2ab84d8ec679adc7f7be1853c8684b44bf899273/lib/loss/loss_contrast_mem.py#L124 and https://github.com/tfzhou/ContrastiveSeg/blob/2ab84d8ec679adc7f7be1853c8684b44bf899273/lib/loss/loss_contrast_mem.py#L131

    I now think the shape of the mask is [anchor_num * anchor_count, class_num * cache_size]. If I have not misunderstood the code, the mask is a 'positive' mask, and each row represents the positive samples of one anchor view.

    Then, in L134-L138, the purpose of logits_mask is confusing: https://github.com/tfzhou/ContrastiveSeg/blob/2ab84d8ec679adc7f7be1853c8684b44bf899273/lib/loss/loss_contrast_mem.py#L134-L138 Could you please explain these lines?

    Suppose I have anchor_num=6 (2 images, 3 valid classes per image), anchor_count=2 (sampling two pixels per class), class_num=5 (number of classes), and cache_size=2 (memory size); then the following code raises a RuntimeError:

    mask = torch.ones((6 * 2, 5 * 2)).scatter_(1, torch.arange(6 * 2).view(-1, 1), 0)
    

    Output:

    Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
    RuntimeError: index 10 is out of bounds for dimension 1 with size 10
    
    opened by Jarvis73 9
  • Can't find the code where it is saving loss weights during training

    Hi @tfzhou

    I've been studying your repo from time to time for a couple of months; I'm implementing something different. I need to change the loss function, and I'm not sure whether it will be saved during training. I couldn't find the lines where the loss weights are saved: in the init functions the weights are loaded from the configer, but I couldn't find where they are saved. The losses should be saved at each epoch, right? Am I thinking about this correctly?

    opened by faruknane 6
  • Question about the subtraction of max from inner product in the loss computation

    https://github.com/tfzhou/ContrastiveSeg/blob/2ab84d8ec679adc7f7be1853c8684b44bf899273/lib/loss/loss_contrast.py#L106

    Would you mind explaining why it is necessary to subtract the maximum of the inner products from each inner product per anchor?
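
    (For context: subtracting the per-anchor maximum is the standard log-sum-exp stabilization. Softmax is invariant to subtracting a constant per row, so the result is mathematically unchanged, while the raw exponentials can overflow in float32. A small illustration:)

        import torch

        logits = torch.tensor([[100.0, 99.0, 98.0]])
        print(torch.exp(logits))             # tensor([[inf, inf, inf]]) in float32
        shifted = logits - logits.max(dim=1, keepdim=True)[0]
        print(torch.exp(shifted) / torch.exp(shifted).sum(dim=1, keepdim=True))
                                             # tensor([[0.6652, 0.2447, 0.0900]])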

    opened by edwardpwtsoi 5
  • code confusing me in def._sample_negative

    As in the code of loss_contrast_mem.py:

    def _sample_negative(self, Q):
        class_num, cache_size, feat_size = Q.shape
        X_ = torch.zeros((class_num * cache_size, feat_size)).float().cuda()
        y_ = torch.zeros((class_num * cache_size, 1)).float().cuda()
        sample_ptr = 0
        for ii in range(class_num):
            if ii == 0: continue   ####?
            this_q = Q[ii, :cache_size, :]
            X_[sample_ptr:sample_ptr + cache_size, ...] = this_q
            y_[sample_ptr:sample_ptr + cache_size, ...] = ii
            sample_ptr += cache_size

        return X_, y_

    I'm wondering what the line if ii == 0: continue is for. It seems to skip class 0 in the memory bank, leaving many zero rows in the returned tensors. A possible explanation would be that class 0 is the background, but I checked the code and I don't think that is the case. Can anyone explain this for me? Thanks!

    opened by ck6698000 3
  • about using your model in other dataset.

    Hi, sir. Thank you for releasing the great code. I am trying to use your model on my own dataset (optic cup and disc segmentation; it has 3 classes). I rewrote the .sh and .json files with "num_classes": 3, "label_list": [1,2,3]. Here are two errors; can you help? Thank you!

    1. RuntimeError: weight tensor should be defined either for all or no classes

        Traceback (most recent call last):
          File "/data/fyw/ContrastiveOpticSegmentation/main_contrastive.py", line 236, in <module>
            model.train()
          File "/data/fyw/ContrastiveOpticSegmentation/segmentor/trainer_contrastive.py", line 420, in train
            self.__train()
          File "/data/fyw/ContrastiveOpticSegmentation/segmentor/trainer_contrastive.py", line 241, in __train
            loss = self.pixel_loss(outputs, targets, with_embed=with_embed)
          File "/data/anaconda3/envs/ContrastiveOpticSegmentation/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
            return forward_call(*input, **kwargs)
          File "/data/fyw/ContrastiveOpticSegmentation/lib/loss/loss_contrast_mem.py", line 218, in forward
            loss = self.seg_criterion(pred, target)
          File "/data/anaconda3/envs/ContrastiveOpticSegmentation/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
            return forward_call(*input, **kwargs)
          File "/data/fyw/ContrastiveOpticSegmentation/lib/loss/loss_helper.py", line 204, in forward
            loss = self.ce_loss(inputs, target)
          File "/data/anaconda3/envs/ContrastiveOpticSegmentation/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
            return forward_call(*input, **kwargs)
          File "/data/anaconda3/envs/ContrastiveOpticSegmentation/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1150, in forward
            return F.cross_entropy(input, target, weight=self.weight,
          File "/data/anaconda3/envs/ContrastiveOpticSegmentation/lib/python3.8/site-packages/torch/nn/functional.py", line 2846, in cross_entropy
            return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

    2. subprocess.CalledProcessError:

        Traceback (most recent call last):
          File "main_contrastive.py", line 185, in <module>
            handle_distributed(args_parser, os.path.expanduser(os.path.abspath(__file__)))
          File "/data/fyw/ContrastiveOpticSegmentation/lib/utils/distributed.py", line 70, in handle_distributed
            raise subprocess.CalledProcessError(returncode=process.returncode,
        subprocess.CalledProcessError: Command '['/data/anaconda3/envs/ContrastiveOpticSegmentation/bin/python', '-u', '-m', 'torch.distributed.launch', '--nproc_per_node', '4', '--master_port', '29962', '/data/fyw/ContrastiveOpticSegmentation/main_contrastive.py', '--configs', 'configs/REFUGE/H_48_D_4_MEM.json', '--drop_last', 'y', '--phase', 'train', '--gathered', 'n', '--loss_balance', 'y', '--log_to_file', 'n', '--backbone', 'hrnet48', '--model_name', 'hrnet_w48_mem', '--gpu', '4', '5', '6', '7', '--data_dir', '/data/dataset/REFUGE', '--loss_type', 'mem_contrast_ce_loss', '--max_iters', '40000', '--train_batch_size', '8', '--checkpoints_root', '/data/fyw/ContrastiveOpticSegmentation/Model/REFUGE/', '--checkpoints_name', 'hrnet_w48_mem_paddle_lr2x_1', '--pretrained', '/data/dataset/hrnetv2_w48_imagenet_pretrained.pth', '--distributed', '--base_lr', '0.01']' returned non-zero exit status 1.

    opened by Lemonweier 3
  • Questions about T-SNE visualization

    Hello, nice work! I am curious about how to visualize embeddings on a 2D plot with t-SNE, and where the embeddings come from.

    The first question is about embedding collection. Specifically, there are two ways of collecting embeddings from images (or feature maps):

    • collect embeddings from only one val image, or from the val images in a single mini-batch;
    • construct a memory bank to collect pixel features across all val images, then randomly select N samples as t-SNE inputs.

    The second question: do the embeddings collected above come from the output of the last conv layer before the classifier, or from the MLP projection head?

    opened by dongdongtong 3
  • why ignore zero?

    Hi, thanks for your outstanding work. I have a question from studying your code: why is a pixel ignored when its class id equals zero?

    https://github.com/tfzhou/ContrastiveSeg/blob/7be326ae9127a847558584eac489d4858adf7cb1/lib/loss/loss_contrast_mem.py#L98

    https://github.com/tfzhou/ContrastiveSeg/blob/7be326ae9127a847558584eac489d4858adf7cb1/segmentor/trainer_contrastive.py#L114

    opened by Junshan233 3
  • Explanation of n_view

    Hi, Thanks for your great work.

    Could you help explain what n_view stands for? https://github.com/tfzhou/ContrastiveSeg/blob/310120712ad4b6ecea45c3ac1143118f959c86e4/lib/loss/loss_contrast.py#L47-L51

    I understand the meaning of total_classes and feat_dim, but it's difficult for me to understand the meaning of n_view here. Thanks.
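
    (For what it's worth, our reading, which is an assumption to verify against _hard_anchor_sampling, is that feats has shape (total_classes, n_view, feat_dim), where n_view is the number of pixel embeddings sampled per (image, class) group; this matches the "sample two pixels per class" example in an earlier comment:)

        import torch
        # hypothetical shapes, following the anchor_num=6, anchor_count=2 example above:
        feats = torch.randn(6, 2, 256)   # 6 (image, class) groups, n_view=2 pixels each, 256-dim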

    opened by HenryPengZou 3
  • Problem in the function "_dequeue_and_enqueue"

    Hello, I have a problem with the line pixel_queue_ptr[lb] = (pixel_queue_ptr[lb] + 1) % self.memory_size (line 138, in the _dequeue_and_enqueue function of trainer_contrastive.py). Should pixel_queue_ptr[lb] + 1 be changed to pixel_queue_ptr[lb] + K? Otherwise, pixel_queue[lb, ptr + 1:ptr + 1 + K, :] will be assigned at the next iteration, which overlaps with pixel_queue[lb, ptr:ptr + K, :].

    opened by eezywu 3
  • The L2 normalization of features

    Dear author,

    I read the code and found that all the embeddings used in the loss function are already L2-normalized in the projection head. Regarding the _dequeue_and_enqueue function, why are the features normalized again at L122 and L134 in trainer_contrastive.py?

    opened by ccccly 2
  • network_stride & _dequeue_and_enqueue

    Thank you very much for your excellent work. I have a few questions to consult you about.

    1. What is the function of the parameter network_stride, and why do you perform this operation on the labels?

    2. I want to know the dimensions of the two parameters (keys, labels) passed to the function _dequeue_and_enqueue; are they the same?

    3. Can your code handle segmentation tasks with labels starting from 0? I didn't understand the statement this_label_ids = [x for x in this_label_ids if x > 0] in the function _dequeue_and_enqueue.

    opened by Darcy103 2
  • problem with resuming training from checkpoint

    Hi... I am getting the following error while resuming training from a checkpoint on a single-GPU system. Training went fine when started from the 0th iteration, but exits immediately after loading a checkpoint. The relevant excerpt that I modified in the run script for that purpose is also shown below. Is it a bug, or is there a mistake somewhere?

    (command used:)

        sh scripts/cityscapes/ocrnet/run_r_101_d_8_ocrnet_train.sh resume x3

    (modification in the script:)

        elif [ "$1"x == "resume"x ]; then
          ${PYTHON} -u main.py --configs ${CONFIGS} \
                               --drop_last y \
                               --phase train \
                               --gathered n \
                               --loss_balance y \
                               --log_to_file n \
                               --backbone ${BACKBONE} \
                               --model_name ${MODEL_NAME} \
                               --max_iters ${MAX_ITERS} \
                               --data_dir ${DATA_DIR} \
                               --loss_type ${LOSS_TYPE} \
                               --resume_continue y \
                               --resume ${CHECKPOINTS_ROOT}/checkpoints/bottle/${CHECKPOINTS_NAME}_latest.pth \
                               --checkpoints_name ${CHECKPOINTS_NAME} \
                               --distributed False \
                               2>&1 | tee -a ${LOG_FILE}
                               # --gpu 0 1 2 3

    2022-11-16 11:30:47,097 INFO [module_runner.py, 87] Loading checkpoint from /workspace/data/defGen/graphics/Pre_CL_x3//..//checkpoints/bottle/spatial_ocrnet_deepbase_resnet101_dilated8_x3_latest.pth...
    2022-11-16 11:30:47,283 INFO [trainer.py, 90] Params Group Method: None
    2022-11-16 11:30:47,285 INFO [optim_scheduler.py, 96] Use lambda_poly policy with default power 0.9
    2022-11-16 11:30:47,285 INFO [data_loader.py, 132] use the DefaultLoader for train...
    2022-11-16 11:30:47,773 INFO [default_loader.py, 38] train 501
    2022-11-16 11:30:47,774 INFO [data_loader.py, 164] use DefaultLoader for val ...
    2022-11-16 11:30:47,873 INFO [default_loader.py, 38] val 126
    2022-11-16 11:30:47,873 INFO [loss_manager.py, 66] use loss: fs_auxce_loss.
    2022-11-16 11:30:47,874 INFO [loss_manager.py, 55] use DataParallelCriterion loss
    2022-11-16 11:30:48,996 INFO [data_helper.py, 126] Input keys: ['img']
    2022-11-16 11:30:48,996 INFO [data_helper.py, 127] Target keys: ['labelmap']

    Traceback (most recent call last):
      File "main.py", line 227, in <module>
        model.train()
      File "/workspace/defGen/External/ContrastiveSeg-main/segmentor/trainer.py", line 390, in train
        self.__train()
      File "/workspace/defGen/External/ContrastiveSeg-main/segmentor/trainer.py", line 196, in __train
        backward_loss = display_loss = self.pixel_loss(outputs, targets,
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/workspace/defGen/External/ContrastiveSeg-main/lib/extensions/parallel/data_parallel.py", line 125, in forward
        return self.module(inputs[0], *targets[0], **kwargs[0])
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/workspace/defGen/External/ContrastiveSeg-main/lib/loss/loss_helper.py", line 309, in forward
        seg_loss = self.ce_loss(seg_out, targets)
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/workspace/defGen/External/ContrastiveSeg-main/lib/loss/loss_helper.py", line 203, in forward
        target = self._scale_target(targets[0], (inputs.size(2), inputs.size(3)))
    IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

    opened by mailtohrishi 0
  • Example for max_views, max_samples?

    First of all, thanks a lot for your brilliant work. I really appreciate it. I have one quick question and would be glad if I could get an answer.

    How do we determine the correct values for max_samples and max_views in the _hard_anchor_sampling() function? I have a dataset with 2 classes (background, class_1) and around 1000 images of size 256x256. What exactly is meant by max_samples and max_views? If you can provide a concrete example, I would be very grateful.

    opened by lennart-maack 0
  • Semi-Hard Example Sampling implementation

    Hello,

    I would like to ask whether an implementation of sampling hard positive/negative pixels for the computation of the contrastive loss is available, because I only found the implementation of hard anchor sampling.

    Thanks

    opened by nysp78 0
  • How to reproduce the SoTA result of Cityscapes?

    Hello, thanks for your impressive work! However, I am not able to reproduce the Cityscapes results using this repo. In the paper, the mIoU of HRNetV2 increases to 81.4, and the mIoU of HRNetV2+OCR increases to 83.2 on the Cityscapes dataset. When I ran the script "run_h_48_d_4_contrast_mem.sh", the mIoU hardly converged. I also notice that the script for HRNet+OCR+Contrastive based on the memory bank is not given; could you suggest how to achieve the experimental results in the paper?

    opened by JasiRose 1
  • Momentum Update

    First of all, I would like to say that this is a real milestone work. But I need to resolve several problems so that I can understand it better.

    1. In the loss_contrast_mem file, I only see encode_q. Is encode_k actually not used in the paper?

    2. What is the difference between ContrastAuxCELoss and ContrastCELoss?

    Looking forward to your reply. Thank you very much!
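
    (For reference, the MoCo-style momentum update that the first question alludes to is typically implemented as below; this is a generic sketch, not this repository's code:)

        import torch

        @torch.no_grad()
        def momentum_update(encoder_q, encoder_k, m=0.999):
            # the key encoder slowly tracks the query encoder
            for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
                p_k.data.mul_(m).add_(p_q.data, alpha=1 - m)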

    opened by Macc520 2