PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers a reference implementation of the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and returns the loss. If labels is None or not passed to it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# should be L2 normalized along the f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...
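To make the [bsz, n_views, f_dim] layout and the role of labels concrete, here is a dependency-free sketch of the supervised contrastive loss in plain Python. It is not the repo's SupConLoss; it assumes the contrast_mode='all' behaviour (every view serves as an anchor) and the temperature / base_temperature scaling, and the helper names l2_normalize, stack_views and supcon_loss are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def stack_views(*views):
    """Combine per-view embeddings, each [bsz][f_dim], into [bsz][n_views][f_dim]."""
    bsz = len(views[0])
    return [[l2_normalize(view[i]) for view in views] for i in range(bsz)]


def supcon_loss(features, labels=None, temperature=0.07, base_temperature=0.07):
    """Supervised contrastive loss with every view used as an anchor.

    features: nested list [bsz][n_views][f_dim] of L2-normalized vectors.
    labels: list of bsz class ids, or None for the SimCLR case.
    Every anchor needs at least one positive; n_views >= 2 guarantees that.
    """
    bsz, n_views = len(features), len(features[0])
    if labels is None:
        # SimCLR: each image is its own class, so only its other views are positives
        labels = list(range(bsz))
    # Flatten view-major: all first views, then all second views, and so on
    z = [features[i][v] for v in range(n_views) for i in range(bsz)]
    y = [labels[i] for _ in range(n_views) for i in range(bsz)]
    n = len(z)
    total = 0.0
    for i in range(n):
        # Similarity of anchor i to every sample, divided by the temperature
        logits = [sum(a * b for a, b in zip(z[i], z[j])) / temperature for j in range(n)]
        # Log of the softmax denominator over all samples except the anchor itself,
        # computed with the max subtracted for numerical stability
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(logits[j] - m) for j in range(n) if j != i))
        positives = [j for j in range(n) if j != i and y[j] == y[i]]
        # Average log-probability of the positives, scaled by temperature / base_temperature
        mean_log_prob_pos = sum(logits[p] - log_denom for p in positives) / len(positives)
        total += -(temperature / base_temperature) * mean_log_prob_pos
    return total / n


# Example: 3 images with two views each; images 0 and 2 share a class
features = stack_views([[1, 2], [3, 1], [1, 3]], [[2, 1], [1, 1], [1, 2]])
supcon = supcon_loss(features, labels=[0, 1, 0])
simclr = supcon_loss(features)  # labels omitted, as with --method SimCLR
```

More than two crops per image only changes n_views: stack all crops along the second axis (the n_views axis of the [bsz, n_views, f_dim] shape above), e.g. stack_views(view1, view2, view3, view4). Also, if every embedding collapses to the same vector, all logits are equal and the loss is (temperature / base_temperature) * log(bsz * n_views - 1) regardless of the labels; with 4 images and 2 views that is log 7 ≈ 1.9459, the constant value reported in the first comment below.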

Comparison

Results on CIFAR-10:

Method           Arch      Setting       Loss           Accuracy(%)
SupCrossEntropy  ResNet50  Supervised    Cross Entropy  95.0
SupContrast      ResNet50  Supervised    Contrastive    96.0
SimCLR           ResNet50  Unsupervised  Contrastive    93.6

Results on CIFAR-100:

Method           Arch      Setting       Loss           Accuracy(%)
SupCrossEntropy  ResNet50  Supervised    Cross Entropy  75.3
SupContrast      ResNet50  Supervised    Contrastive    76.5
SimCLR           ResNet50  Unsupervised  Contrastive    70.7

Results on ImageNet (Stay tuned):

Method           Arch      Setting       Loss           Accuracy(%)
SupCrossEntropy  ResNet50  Supervised    Cross Entropy  -
SupContrast      ResNet50  Supervised    Contrastive    79.1 (MoCo trick)
SimCLR           ResNet50  Unsupervised  Contrastive    -

Running

You can use CUDA_VISIBLE_DEVICES to choose which and how many GPUs to use, and/or switch to CIFAR-100 with --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN
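The --cosine flag turns on a cosine learning-rate schedule. Below is a minimal sketch of standard cosine annealing; the function name cosine_lr and the floor eta_min are assumptions, and the repo's exact floor (and any warm-up it applies) may differ.

```python
import math


def cosine_lr(base_lr, epoch, total_epochs, eta_min=0.0):
    """Cosine annealing: base_lr at epoch 0, falling smoothly to eta_min at total_epochs."""
    progress = epoch / total_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2


# Example: --learning_rate 0.8 from the command above over a hypothetical
# 100-epoch run, sampled every 25 epochs
schedule = [cosine_lr(0.8, epoch, 100) for epoch in range(0, 101, 25)]
```

This sketch assumes one update per epoch; a per-iteration variant would pass a fractional epoch so the rate decays smoothly within each epoch.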

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN, but I found it not crucial for SupContrast (syncBN 95.9% vs. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to the SupConLoss criterion.
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth
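Because --method SimCLR only withholds the labels, SupContrast and SimCLR differ in which pairs count as positives. The plain-Python sketch below builds that positive-pair mask for the [bsz, n_views] layout, flattening views view-major; the helper names positive_mask and simclr_mask are hypothetical and this is not the repo's tensor code.

```python
def positive_mask(labels, n_views):
    """Return the n x n 0/1 matrix of positive pairs, with n = len(labels) * n_views.

    Rows and columns follow a view-major flattening: all first views, then all
    second views, and so on. An entry is 1 when two different rows share a label;
    the diagonal is 0 because an anchor is never its own positive.
    """
    y = [label for _ in range(n_views) for label in labels]
    n = len(y)
    return [[1 if i != j and y[i] == y[j] else 0 for j in range(n)] for i in range(n)]


def simclr_mask(bsz, n_views):
    # Without labels, each image acts as its own class: only its other views are positives
    return positive_mask(list(range(bsz)), n_views)


# Example: 2 images, 2 views each
unsupervised = simclr_mask(2, 2)
supervised = positive_mask([7, 7], 2)  # both images share class 7
```

With labels, every other view of every same-class image becomes a positive, which is the only change between the two pretraining commands besides the temperature and --syncBN settings.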

On a custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR
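The --mean and --std values are per-channel normalization statistics (one number per RGB channel, on a 0 to 1 scale). For a custom dataset they can be computed from the training images; below is a minimal plain-Python sketch that assumes the images are already decoded into lists of 8-bit (r, g, b) tuples. The function names are hypothetical, and pooling over all pixels with a population standard deviation is one common convention, not necessarily the one behind the numbers above.

```python
import math


def channel_mean_std(images):
    """Per-channel mean and std of RGB images scaled to [0, 1].

    images: iterable of images, each a list of (r, g, b) tuples with values in 0..255.
    Statistics are pooled over every pixel of every image (population std).
    """
    count = 0
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    for image in images:
        for pixel in image:
            count += 1
            for c in range(3):
                v = pixel[c] / 255.0
                sums[c] += v
                sq_sums[c] += v * v
    mean = [s / count for s in sums]
    # max() guards against tiny negative values from floating-point cancellation
    std = [math.sqrt(max(0.0, sq / count - m * m)) for sq, m in zip(sq_sums, mean)]
    return mean, std


def as_flag(values):
    """Format three numbers the way the --mean/--std arguments above are written."""
    return "(" + ", ".join(f"{v:.4f}" for v in values) + ")"


# Example: one tiny image with a black and a white pixel
mean, std = channel_mean_std([[(0, 0, 0), (255, 255, 255)]])
```

The running sums can be accumulated batch by batch, so the whole dataset never has to be held in memory at once.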

The --data_folder must be of the form ./path/label/xxx.png, following the https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convention.
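A rough plain-Python sketch of how an ImageFolder-style tree maps to labels is below. It assumes the torchvision convention of one subdirectory per class, with class indices assigned in sorted order of the folder names; the function name image_folder_samples and the short extension list are hypothetical, and torchvision's own loader accepts more extensions and also decodes the images.

```python
import os


def image_folder_samples(root, extensions=(".png", ".jpg", ".jpeg")):
    """List (path, class_index) pairs for a root/label/xxx.png layout.

    Each subdirectory of root is one class; class indices follow the
    sorted subdirectory names, and files are visited in sorted order.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(filenames):
                # Skip files whose extension is not an accepted image type
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples
```

Because indices come from sorted folder names, renaming or adding a class folder can shift every label, so the same folder names should be used for pretraining and linear evaluation.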

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Comments
  • Loss saturates after several iterations

    Hi, thanks for sharing. I tried to test your loss function, but the loss stays constant after a while.

    The problem is that learning is going in the wrong direction. The dot products of the model's outputs are all 1; does that mean all vectors are grouped together regardless of label?

    What should I check? Thanks.

    -> loss 1.9459104537963867 -> label tensor([3, 5, 0, 1])
    -> dot production of augmented data's output tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

    1.9459099769592285 tensor([4, 4, 3, 3]) tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=)

    opened by YuBeomGon 10
  • Input image resolution for ImageNet

    Hi,

    I reran the repo code on my 2 x 32GB V100 GPUs.

    If I use 224 x 224 images, my GPUs can only fit 10 images in a batch in total.

    What is the resolution you use in your paper for ImageNet?

    opened by littleredxh 8
  • Visualization results with t-SNE

    Hello. Thank you for providing this great code. I would like to see the visualization results when I run it myself. It would be great if you could provide the code for visualizing with t-SNE. Thank you for your help.

    opened by Hiiragi0107 7
  • SyncBN?

    Hi @HobbitLong

    I see you use SyncBN from apex to train with DataParallel; however, SyncBN seems to be designed for DistributedDataParallel. Could you please confirm whether SyncBN works in this case?

    Best, Jizong

    opened by jizongFox 6
  • Training time

    Hi @HobbitLong Thanks for the wonderful repo.

    I am currently using a machine with 2 GPUs (1080) to reproduce SimCLR and SupContrast on CIFAR-10. Could you please tell me how long it takes for each method to converge? I have had some terrible experiences waiting for weeks with some parallel approaches.

    Thanks for the information. Jizong

    opened by jizongFox 5
  • Support custom dataset

    Hi,

    I wanted to test out SimCLR and SupContrast on custom datasets and see how they perform. In the PR, I made minimal changes to support passing in a path to the dataset, mean, and std to initialize a PyTorch ImageFolder dataset.

    Thanks @zlapp

    opened by zlapp 4
  • base_temperature

    Thanks for sharing the code!

    What is the intuition behind base_temperature? I found that its value is always 0.07. Is it just for scaling the learning rate by temperature / 0.07? https://github.com/HobbitLong/SupContrast/blob/master/losses.py#L95

    opened by kibok90 4
  • about SupConLoss

    I appreciate your great work! I want to figure out contrast_count = features.shape[1]. The value of features.shape[1] is n_views, and I can't understand what n_views represents. Is it always equal to 2?

    opened by Dingkx9 3
  • Is there a big difference between one-stage and two-stage training?

    I am a beginner in contrastive learning. I have one question.

    Why isn't the supervised contrastive loss trained as a regularization term alongside a classifier (with cross-entropy)? What are the benefits of this two-stage training compared with one-stage training?

    Looking forward to your reply.

    opened by TianQi-777 3
  • About Training

    I ran this code on my lab computer but got the following error:

    UserWarning: There is an imbalance between your GPUs. You may want to exclude GPU 4 which has less than 75% of the memory or cores of GPU 0. You can do so by setting the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES environment variable.
    warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
    terminate called after throwing an instance of 'std::runtime_error'
    what(): NCCL Error 1: unhandled cuda error

    I am very upset and confused. Can anyone help me?

    opened by xianxuan-z 2
  • # of Batch

    Hi. Thanks for your great research and implementation. However, I am not sure how to use this loss function in a multi-batch situation. I mean, instead of using 2 batches, I am trying to use 4 batches.

    Is it right that I just pass torch.cat([b1, b2, b3, b4, b5]) with the class labels for b1?

    Thanks a lot.

    opened by wjun0830 2
  •   File

      File "main_supcon.py", line 270, in main
        opt = parse_option()
      File "main_supcon.py", line 94, in parse_option
        assert opt.degrees is not None
    AssertionError

    Hi, thank you for your work. I want to redo your work with my custom dataset. I followed the guidance you provided, but I still face this error and I don't know why. Can you please tell me why I am still facing this issue?

    opened by noreenanwar 0
  • the mean and std of cifar dataset

    Hi, thanks for your excellent work and code! I have a question about the mean and std of the CIFAR dataset. For example, for CIFAR-10, why do you use transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) instead of transforms.Normalize([x/255.0 for x in [125.3, 123.0, 113.9]], [x/255.0 for x in [63.0, 62.1, 66.7]])? I am looking forward to your reply! Thanks!

    opened by FengShuai-bupt 0
  • Loss function is not convergent when batch-sizes smaller?

    Hello and thank you for your work. I have some small questions about training. I only have one GPU, a 2080 Ti with only 8 GB of memory, so I cannot run the parameters you set and can only reduce the batch size. But after running 100 epochs on CIFAR-10, the loss shows no sign of converging and validation accuracy is only about 30%. Why is that? A simple AlexNet can reach 80% accuracy with CE loss in a few epochs. Is my learning rate set wrong? I kept the 0.5 you suggested. If so, what should I set when the batch size is smaller? Or am I not training for enough epochs? Thank you for your answer!

    opened by kyre-99 0
  • Why use two anchors?

    Dear author,

    I don't understand why we should use two anchors. What is the benefit of using two anchors?

        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count    # here, in "all" mode, anchor_count is 2
    

    May I ask why? Thank you.

    opened by qiaoyu1002 0
  • Supervised Contrastive Learning with n_views=1

    Hello, first, thank you for this paper and code! I want to adapt your approach to use Supervised Contrastive Learning after a region proposal network.

    In this context, I was wondering, as it's not explicitly said in the paper: what are your exact motivations for including two views of the images in each batch? From my understanding/intuition, it's to ensure that the anchor is exposed to a "decent"/minimum number of positive samples, is that right?

    Thank you!

    opened by piconti 0
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Unofficial PyTorch implementation of SimCLR by Google Brain

Rishabh Anand 2 Oct 13, 2021
Saeed Lotfi 28 Dec 12, 2022
Self-Supervised Contrastive Learning of Music Spectrograms

Self-Supervised Music Analysis Self-Supervised Contrastive Learning of Music Spectrograms Dataset Songs on the Billboard Year End Hot 100 were collect

null 27 Dec 10, 2022
Supervised Contrastive Learning for Downstream Optimized Sequence Representations

SupCL-Seq Supervised Contrastive Learning for Downstream Optimized Sequence representations (SupCS-Seq) accepted to be published in EMNLP 2021, ext

Hooman Sedghamiz 18 Oct 21, 2022
Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning.

Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning. Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive

HLT@HIT(SZ) 7 Dec 16, 2021
Supervised Contrastive Learning for Product Matching

Contrastive Product Matching This repository contains the code and data download links to reproduce the experiments of the paper "Supervised Contrasti

Web-based Systems Group @ University of Mannheim 18 Dec 10, 2022
ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... An implementation meant to help understand the model's concepts; the variable names are currently written out in detail, and

BG Kim 3 Oct 6, 2022
Fang Zhonghao 13 Nov 19, 2022
An essential implementation of BYOL in PyTorch + PyTorch Lightning

Essential BYOL A simple and complete implementation of Bootstrap your own latent: A new approach to self-supervised Learning in PyTorch + PyTorch Ligh

Enrico Fini 48 Sep 27, 2022
RealFormer-Pytorch Implementation of RealFormer using pytorch

RealFormer-Pytorch Implementation of RealFormer using pytorch. Includes comparison with classical Transformer on image classification task (ViT) wrt C

Simo Ryu 90 Dec 8, 2022
A PyTorch implementation of the paper Mixup: Beyond Empirical Risk Minimization in PyTorch

Mixup: Beyond Empirical Risk Minimization in PyTorch This is an unofficial PyTorch implementation of mixup: Beyond Empirical Risk Minimization. The co

Harry Yang 121 Dec 17, 2022
A pytorch implementation of Pytorch-Sketch-RNN

Pytorch-Sketch-RNN A pytorch implementation of https://arxiv.org/abs/1704.03477 In order to draw other things than cats, you will find more drawing da

Alexis David Jacq 172 Dec 12, 2022
PyTorch implementation of Advantage async actor-critic Algorithms (A3C) in PyTorch

Advantage async actor-critic Algorithms (A3C) in PyTorch @inproceedings{mnih2016asynchronous, title={Asynchronous methods for deep reinforcement lea

LEI TAI 111 Dec 8, 2022
Pytorch-diffusion - A basic PyTorch implementation of 'Denoising Diffusion Probabilistic Models'

PyTorch implementation of 'Denoising Diffusion Probabilistic Models' This reposi

Arthur Juliani 76 Jan 7, 2023
RETRO-pytorch - Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch

RETRO - Pytorch (wip) Implementation of RETRO, Deepmind's Retrieval based Attent

Phil Wang 556 Jan 4, 2023
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives

HashNeRF-pytorch Instant-NGP recently introduced a Multi-resolution Hash Encodin

Yash Sanjay Bhalgat 616 Jan 6, 2023
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pytorch Lightning 1.4k Jan 1, 2023
PyTorch implementation of: Michieli U. and Zanuttigh P., "Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations", CVPR 2021.

Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations This is the official PyTorch implementation

Multimedia Technology and Telecommunication Lab 42 Nov 9, 2022
PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence) and pre-trained model on ImageNet dataset

Reference-Based-Sketch-Image-Colorization-ImageNet This is a PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization usin

Yuzhi ZHAO 11 Jul 28, 2022