Implementation of Supervised Contrastive Learning with AMP, EMA, SWA, and many other tricks

Overview

SupCon-Framework

The repo is an implementation of Supervised Contrastive Learning. It's based on another implementation, but with several differences:

  • Fixed bugs (incorrect ResNet implementations, which led to a very small maximum batch size),
  • A lot of additional functionality (first of all, rich validation).

To be more precise, in this implementation you will find:

  • Augmentations with albumentations
  • Hyperparameters are moved to .yml configs
  • t-SNE visualizations
  • 2-step validation (for features before and after the projection head) using metrics like AMI, NMI, mAP, precision_at_1, etc. with PyTorch Metric Learning.
  • Exponential Moving Average (EMA) for more stable training, and Stochastic Weight Averaging (SWA) for better generalization and overall performance.
  • Automatic Mixed Precision (torch version) training in order to be able to train with a bigger batch size (roughly by a factor of 2).
  • LabelSmoothing loss, and LRFinder for the second stage of the training (FC).
  • TensorBoard logs, checkpoints
  • Support of timm models, and pytorch-optimizer
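To make the LabelSmoothing item above concrete, here is a minimal pure-Python sketch of a label-smoothed cross entropy for a single sample. It is not the repo's implementation: the function name, the signature, and the choice to spread the smoothing mass equally over the non-target classes are assumptions made for illustration (other variants spread it over all classes).

```python
import math


def label_smoothing_loss(logits, target, smoothing=0.1):
    """Cross entropy against a smoothed target distribution (one sample).

    Assumption: the true class gets weight 1 - smoothing and the remaining
    classes share `smoothing` equally.
    """
    n_classes = len(logits)
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]

    off_value = smoothing / (n_classes - 1)
    loss = 0.0
    for k, log_p in enumerate(log_probs):
        weight = 1.0 - smoothing if k == target else off_value
        loss -= weight * log_p
    return loss
```

With `smoothing=0.0` this reduces to the regular cross entropy, so the parameter only controls how strongly overconfident predictions are penalized.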

Install

  1. Clone the repo:
git clone https://github.com/ivanpanshin/SupCon-Framework && cd SupCon-Framework/
  2. Create a clean virtual environment:
python3 -m venv venv
source venv/bin/activate
  3. Install dependencies:
python -m pip install --upgrade pip
pip install -r requirements.txt

Training

To train on CIFAR10, run the following commands in order (each training stage is followed by its SWA step):

python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar10_stage1.yml
python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar10_stage2.yml

In order to run LRFinder on the second stage of the training, run:

python learning_rate_finder.py --config_name configs/train/lr_finder_supcon_resnet18_cifar10_stage2.yml

The process of training on CIFAR100 is exactly the same; just change the config names from cifar10 to cifar100.

After that, you can check the results of the training in either the logs or the runs directory. For example, to check the TensorBoard logs for the first stage of CIFAR10 training, run:

tensorboard --logdir runs/supcon_first_stage_cifar10

Visualizations

This repo is supplied with t-SNE visualizations so that you can check embeddings you get after the training. Check t-SNE.ipynb for details.

These are t-SNE visualizations for CIFAR10 for validation and train with SupCon (top), and validation and train with CE (bottom).

These are t-SNE visualizations for CIFAR100 for validation and train with SupCon (top), and validation and train with CE (bottom).

Results

Model Stage Dataset Accuracy (%)
ResNet18 First CIFAR10 95.9
ResNet18 Second CIFAR10 94.9
ResNet18 First CIFAR100 79.0
ResNet18 Second CIFAR100 77.9

Note that even though the accuracy on the second stage is lower, it's not always the case. In my experience, the difference between stages is usually around 1 percent, including the difference that favors the second stage.

Training time for the whole pipeline (without any early stopping) on CIFAR10 or CIFAR100 is around 4 hours (single 2080Ti with AMP). However, with reasonable early stopping that value goes down to around 2.5-3 hours.

Custom datasets

It's fairly easy to adapt this pipeline to custom datasets. First, you need to check tools/datasets.py for that. Second, add a new class for your dataset. The only guideline here is to follow the same augmentation logic, that is

        if self.second_stage:
            image = self.transform(image=image)['image']
        else:
            image = self.transform(image)

Third, add your dataset to DATASETS dict still inside tools/datasets.py, and you're good to go.
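As a rough illustration of the three steps above, a custom dataset class might look like the sketch below. The class name `MyDataset`, the `load_samples` helper, and the `'my_dataset'` key are hypothetical. The constructor arguments mirror the call `DATASETS[dataset_name](data_dir, train, transform, second_stage)` that appears in a traceback further down this page, and the real classes in tools/datasets.py may differ in other respects (for example, they may subclass torchvision datasets).

```python
class MyDataset:
    """Hypothetical custom dataset; names here are illustrative only."""

    def __init__(self, data_dir, train, transform, second_stage):
        self.transform = transform
        self.second_stage = second_stage
        self.samples = self.load_samples(data_dir, train)

    @staticmethod
    def load_samples(data_dir, train):
        # Placeholder: replace with code that reads (image, label) pairs
        # for the train or validation split from data_dir.
        return []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, label = self.samples[idx]
        if self.second_stage:
            # albumentations-style call: keyword argument in, dict out.
            image = self.transform(image=image)['image']
        else:
            # First stage: the transform is called directly on the image.
            image = self.transform(image)
        return image, label


# Hypothetical registration; in the repo this entry would be added to the
# existing DATASETS dict in tools/datasets.py.
DATASETS = {'my_dataset': MyDataset}
```

The only part taken directly from the guideline above is the `if self.second_stage` branch; everything else is scaffolding so that the branch has a context to run in.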

FAQ

  • Q: What hyperparameters should I try to change?

    A: First of all, the learning rate. Second, try to change the augmentation policy. SupCon is built around a "cropping + color jittering" scheme, so you can try changing the cropping size or the intensity of jittering. Check tools.utils.build_transforms for that.

  • Q: What backbone and batch size should I use?

    A: This is quite simple. Take the biggest backbone you can, and after that take the highest batch size your GPU can offer. The reason for that: SupCon benefits from stronger backbones more than regular classification training (CE, LabelSmoothing, etc.) does. Moreover, it has a property of explicit hard positive and negative mining. It means that the higher the batch size, the more difficult and helpful samples you supply to your model.
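To see why the batch size matters, it can help to look at the loss itself. The sketch below is a plain-Python rendering of the supervised contrastive loss as formulated in the original SupCon paper, for one view per sample and L2-normalized embeddings. It is written for readability and is not the repo's implementation; the function name and default temperature are assumptions.

```python
import math


def supcon_loss(embeddings, labels, temperature=0.1):
    """Supervised contrastive loss, averaged over anchors with a positive.

    Assumes `embeddings` is a list of L2-normalized vectors.
    """
    n = len(embeddings)
    total, n_anchors = 0.0, 0
    for i in range(n):
        # Scaled similarity of anchor i to every other sample in the batch.
        sims = {
            a: sum(x * y for x, y in zip(embeddings[i], embeddings[a])) / temperature
            for a in range(n)
            if a != i
        }
        positives = [a for a in sims if labels[a] == labels[i]]
        if not positives:
            continue  # an anchor without positives contributes nothing
        # Log of the denominator: all other samples, positives and negatives.
        m = max(sims.values())
        log_denom = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        total += -sum(sims[p] - log_denom for p in positives) / len(positives)
        n_anchors += 1
    return total / n_anchors if n_anchors else 0.0
```

Every other sample in the batch enters the denominator for each anchor, so a bigger batch gives each anchor more positives and negatives to be compared against.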

  • Q: Do I need the second stage of the training?

    A: Not necessarily. You can do classification based only on embeddings. In order to do that, compute embeddings for the train set, and at inference time do the following: take a sample, compute its embedding, find the closest embedding from the train set, and take its class. To make this fast and efficient, you can use something like faiss for similarity search. Note that this is actually how validation is done in this repo. Moreover, during training you will see a metric precision_at_1. This is actually just accuracy based solely on embeddings.
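The nearest-neighbor procedure described above can be sketched in a few lines of plain Python. This is a brute-force illustration only: the function names are hypothetical, cosine similarity is an assumed choice of "closest", and embeddings are assumed to be non-zero vectors. For real datasets a similarity search library such as faiss would replace the linear scan.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def predict_1nn(query, train_embeddings, train_labels):
    """Return the label of the train embedding closest to `query`."""
    best = max(
        range(len(train_embeddings)),
        key=lambda i: cosine_similarity(query, train_embeddings[i]),
    )
    return train_labels[best]


def precision_at_1(queries, query_labels, train_embeddings, train_labels):
    """Fraction of queries whose nearest train neighbor has the same label."""
    hits = sum(
        predict_1nn(q, train_embeddings, train_labels) == y
        for q, y in zip(queries, query_labels)
    )
    return hits / len(queries)
```

Computed this way, precision_at_1 is simply the accuracy of the 1-nearest-neighbor classifier over the query set.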

  • Q: Should I use AMP?

    A: If your GPU has tensor cores (like 2080Ti) - yes. If it doesn't (like 1080Ti) - check the speed with AMP and without. If the speed dropped slightly (or even increased by a bit) - use it, since SupCon works better with bigger batch sizes.

  • Q: How should I use EMA?

    A: You only need to choose the ema_decay_per_epoch parameter in the config. The heuristic is fairly simple. If your dataset is big, then something as small as 0.3 will do just fine. And as your dataset gets smaller, you can increase ema_decay_per_epoch. Thanks to bonlime for this idea. I advise you to check his great pytorch tools repo, it's a hidden gem.
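One plausible reading of a per-epoch decay is that it gets converted into a per-step decay by taking the root over the number of optimizer steps in an epoch. The README does not spell this out, so treat the sketch below as an assumption rather than a description of the repo's code; both function names are hypothetical.

```python
def per_step_decay(ema_decay_per_epoch, steps_per_epoch):
    """Per-step decay that compounds to the per-epoch decay over one epoch."""
    return ema_decay_per_epoch ** (1.0 / steps_per_epoch)


def ema_update(shadow, params, decay):
    """One EMA step: move the shadow weights towards the current weights."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

Under this reading, a big dataset with 1000 steps per epoch and ema_decay_per_epoch = 0.3 gets a per-step decay of about 0.9988, while a small dataset with 10 steps per epoch gets about 0.887. That would explain the heuristic: fewer steps per epoch call for a higher per-epoch value to keep the averaging window comparable.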

  • Q: Is it better than training with Cross Entropy/Label Smoothing/etc?

    A: Unfortunately, in my experience, it's much easier to get better results with something like CE. It's more stable, faster to train, and simply produces better or the same results. For instance, on CIFAR10/100 it's trivial to train ResNet18 up to 96/81 percent accuracy, respectively. Of course, I've seen cases where SupCon performs better, but it takes quite a bit of work to make it outperform CE.

  • Q: How long should I train with SupCon?

    A: The answer is tricky. On one hand, the authors of the original paper claim that the longer you train with SupCon, the better it gets. However, I did not observe such behavior in my tests. So the only recommendation I can give is the following: start with 100 epochs for easy datasets (like CIFAR10/100), and 1000 for more industrial ones. Then monitor the training process. If the validation metric (such as precision_at_1) doesn't improve for several dozen epochs, you can stop the training. You might incorporate early stopping into the pipeline for this reason.
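The stopping rule described above can be sketched as a small helper. The class below is hypothetical and not part of the repo; it assumes a "higher is better" validation metric such as precision_at_1, and the default patience of 30 epochs is just one reading of "several dozen".

```python
class EarlyStopping:
    """Signal a stop when the metric has not improved for `patience` epochs."""

    def __init__(self, patience=30, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.epochs_without_improvement = 0

    def step(self, metric):
        """Record one validation metric; return True if training should stop."""
        if self.best is None or metric > self.best + self.min_delta:
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```

In a training loop, one would call `step` once per epoch with the validation metric and break out of the loop when it returns True.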

Comments
  • ValueError in Stage2

    Hi, I'm trying to run your code on CIFAR-10. The training and SWA in stage1 were fine, but I got the following error when training stage2:

    root@864d7f9c24b4:/SupCon-Framework# python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml 
    {'model': {'backbone': 'resnet18', 'ckpt_pretrained': 'weights/supcon_first_stage_cifar10/swa', 'num_classes': 10}, 'train': {'n_epochs': 20, 'amp': True, 'ema': True, 'ema_decay_per_epoch': 0.3, 'logging_name': 'supcon_second_stage_cifar10', 'target_metric': 'accuracy', 'stage': 'second'}, 'dataset': 'data/cifar10', 'dataloaders': {'train_batch_size': 20, 'valid_batch_size': 20, 'num_workers': 12}, 'optimizer': {'name': 'SGD', 'params': {'lr': 0.01}}, 'scheduler': {'name': 'CosineAnnealingLR', 'params': {'T_max': 20, 'eta_min': 0.001}}, 'criterion': {'name': 'LabelSmoothing', 'params': {'classes': 10, 'smoothing': 0.01}}}
    Files already downloaded and verified
    Files already downloaded and verified
    Traceback (most recent call last):
      File "train.py", line 111, in <module>
        train_metrics = utils.train_epoch_ce(loaders['train_features_loader'], model, criterion, optimizer, scaler, ema)
      File "/SupCon-Framework/tools/utils.py", line 250, in train_epoch_ce
        ema.update(model.parameters())
      File "/usr/local/lib/python3.8/dist-packages/torch_ema/ema.py", line 88, in update
        parameters = self._get_parameters(parameters)
      File "/usr/local/lib/python3.8/dist-packages/torch_ema/ema.py", line 65, in _get_parameters
        raise ValueError(
    ValueError: Number of parameters passed as argument is different from number of shadow parameters maintained by this ExponentialMovingAverage
    

    Another minor problem is GPU usage. I used to run another implementation of SupContrast. It requires 8x GPU memory (and higher utilization of each GPU) to train stage1 of the same backbone and batch size. Do you know what causes that difference?

    opened by JiarunLiu 4
  • Train in 2 steps instead of 4?

    Hi, thanks for your neat work here. It looks like there are currently 4 proposed steps to train: train.py and swa.py for stage 1, then train.py and swa.py for stage 2.

    Is there a way to pre-train using supcon loss once, and classification training once, for a total of two steps?

    opened by ibarrien 3
  • Can SupCon loss be used in multi-label classification?

    I have a text multi-label classification task. Can I use SupCon loss? The SupCon loss would be accumulated over every label view. For example, with batch labels = [[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 1, 1]]:

      • from the view of label 0: positive examples = {0, 2}, negative examples = {1, 3}
      • from the view of label 1: positive examples = {1, 2, 3}, negative examples = {0}
      • from the view of label 2: positive examples = {0, 1, 3}, negative examples = {2}

    Is this setting reasonable?

    opened by littttttlebird 2
  • Run custom dataset

    Hello everyone, I am trying to run my own dataset, but I face the same error all the time.

    Traceback (most recent call last):
      File "train.py", line 64, in <module>
        loaders = utils.build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == 'second'))
      File "/home/karantai/SupCon-Framework/tools/utils.py", line 93, in build_loaders
        transform=transforms['valid_transforms'], second_stage=True)
      File "/home/karantai/SupCon-Framework/tools/datasets.py", line 79, in create_supcon_dataset
        return DATASETS[dataset_name](data_dir, train, transform, second_stage)
      File "/home/karantai/SupCon-Framework/tools/datasets.py", line 50, in __init__
        super().__init__(root=data_dir, train=train, download=False, transform=transform)
      File "/home/karantai/anaconda3/envs/myenv/lib/python3.6/site-packages/torchvision/datasets/cifar.py", line 69, in __init__
        ' You can use download=True to download it')
    RuntimeError: Dataset not found or corrupted. You can use download=True to download it

    I have set Download = False. And I have created a pickled dict with my data as the cifar 10 is. Maybe my problem is trivial, but I would appreciate any help in utils.dataset.py modification or the dataset's layout.

    Thank you in advance guys!

    opened by geokarant 1
  • SupCon loss

    @ivanpanshin Thanks for the wonderful implementation. Does the fact that the SupCon loss can be seen as a generalization of the triplet loss mean that we should expect the embeddings to lie either in a Euclidean space or on a low-dimensional manifold?

    opened by agporto 0
  • Suggest to loosen the dependency on albumentations

    Hi, your project SupCon-Framework requires "albumentations==0.5.2" in its dependency. After analyzing the source code, we found that the following versions of albumentations can also be suitable without affecting your project, i.e., albumentations 0.5.1. Therefore, we suggest to loosen the dependency on albumentations from "albumentations==0.5.2" to "albumentations>=0.5.1,<=0.5.2" to avoid any possible conflict for importing more packages or for downstream projects that may use SupCon-Framework.

    May I open a pull request to loosen the dependency on albumentations?

    By the way, could you please tell us whether such dependency analysis would be helpful for maintaining dependencies during your development?



    We also give our detailed analysis as follows for your reference:

    Your project SupCon-Framework directly uses 8 APIs from package albumentations.

    albumentations.augmentations.transforms.Resize.__init__, albumentations.core.composition.Compose.__init__, albumentations.augmentations.transforms.RandomResizedCrop.__init__, albumentations.pytorch.transforms.ToTensorV2.__init__, albumentations.augmentations.transforms.ToGray.__init__, albumentations.augmentations.transforms.Rotate.__init__, albumentations.augmentations.transforms.ColorJitter.__init__, albumentations.augmentations.transforms.Normalize.__init__
    
    

    Beginning from the 8 APIs above, 14 functions are then indirectly called, including 14 albumentations's internal APIs and 0 outsider APIs. The specific call graph is listed as follows (neglecting some repeated function occurrences).

    [/ivanpanshin/SupCon-Framework]
    +--albumentations.augmentations.transforms.Resize.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    +--albumentations.core.composition.Compose.__init__
    |      +--albumentations.core.composition.BaseCompose.__init__
    |      |      +--albumentations.core.composition.Transforms.__init__
    |      |      |      +--albumentations.core.composition.Transforms._find_dual_start_end
    |      |      |      |      +--albumentations.core.composition.Transforms._find_dual_start_end
    |      +--albumentations.augmentations.bbox_utils.BboxProcessor.__init__
    |      |      +--albumentations.core.utils.DataProcessor.__init__
    |      +--albumentations.core.composition.BboxParams.__init__
    |      |      +--albumentations.core.utils.Params.__init__
    |      +--albumentations.augmentations.keypoints_utils.KeypointsProcessor.__init__
    |      |      +--albumentations.core.utils.DataProcessor.__init__
    |      +--albumentations.core.composition.KeypointParams.__init__
    |      |      +--albumentations.core.utils.Params.__init__
    |      +--albumentations.core.composition.BaseCompose.add_targets
    +--albumentations.augmentations.transforms.RandomResizedCrop.__init__
    |      +--albumentations.augmentations.transforms._BaseRandomSizedCrop.__init__
    |      |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    +--albumentations.pytorch.transforms.ToTensorV2.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    +--albumentations.augmentations.transforms.ToGray.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    +--albumentations.augmentations.transforms.Rotate.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    |      +--albumentations.core.transforms_interface.to_tuple
    +--albumentations.augmentations.transforms.ColorJitter.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    |      +--albumentations.augmentations.transforms.ColorJitter.__check_values
    +--albumentations.augmentations.transforms.Normalize.__init__
    |      +--albumentations.core.transforms_interface.BasicTransform.__init__
    

    We scanned the versions of albumentations and observed that between 0.5.1 and 0.5.2 the changed functions (diffs listed below) have no intersection with any function or API mentioned above (either directly or indirectly called by this project).

    diff: 0.5.2(original) 0.5.1
    ['albumentations.augmentations.transforms.MedianBlur', 'albumentations.augmentations.transforms.CropNonEmptyMaskIfExists.targets_as_params', 'albumentations.augmentations.transforms.GaussianBlur', 'albumentations.augmentations.transforms.CropNonEmptyMaskIfExists.update_params', 'albumentations.pytorch.transforms.ToTensorV2', 'albumentations.pytorch.transforms.ToTensorV2.apply', 'albumentations.augmentations.transforms.CropNonEmptyMaskIfExists.get_params_dependent_on_targets', 'albumentations.augmentations.transforms.CropNonEmptyMaskIfExists', 'albumentations.augmentations.transforms.CropNonEmptyMaskIfExists._preprocess_mask']
    
    

    Therefore, we believe that it is quite safe to loosen your dependency on albumentations from "albumentations==0.5.2" to "albumentations>=0.5.1,<=0.5.2". This will improve the applicability of SupCon-Framework and reduce the possibility of any further dependency conflict with other projects.

    opened by Agnes-U 0
Owner
Ivan Panshin
Machine Learning Engineer: CV, NLP, tabular data. Kaggle (top 0.003% worldwide) and Open Source