TorchSSL: A PyTorch-based Toolbox for Semi-Supervised Learning

Overview

An all-in-one toolkit based on PyTorch for semi-supervised learning (SSL). We implement 9 popular SSL algorithms to enable fair comparison and to boost the development of SSL algorithms.

FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling (https://arxiv.org/abs/2110.08263)

Supported algorithms

We support fully supervised training and the 9 popular SSL algorithms listed below:

  1. Pi-Model [1]
  2. MeanTeacher [2]
  3. Pseudo-Label [3]
  4. VAT [4]
  5. MixMatch [5]
  6. UDA [6]
  7. ReMixMatch [7]
  8. FixMatch [8]
  9. FlexMatch [9]

In addition, we apply our Curriculum Pseudo Labeling (CPL) method to Pseudo-Label (Flex-Pseudo-Label) and UDA (Flex-UDA).
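The class-wise thresholds behind CPL can be sketched in a few lines. This is a minimal pure-Python illustration based on the description in the FlexMatch paper, not the TorchSSL implementation: the learning effect σ(c) counts unlabeled samples whose stored prediction is class c with confidence above the fixed threshold τ; it is normalized to β(c) = σ(c) / max_c σ(c) (with threshold warmup, the denominator is max{max_c σ(c), number of still-unused samples}); and the flexible threshold of class c is τ · M(β(c)) with the convex mapping M(x) = x / (2 − x). The class name FlexibleThresholds, the UNUSED marker, and the default τ = 0.95 are assumptions made for illustration.

```python
from collections import Counter
from typing import List, Sequence

UNUSED = -1  # hypothetical marker: sample has never been predicted above tau


class FlexibleThresholds:
    """Minimal sketch of CPL class-wise thresholds (illustrative, not TorchSSL code)."""

    def __init__(self, num_classes: int, num_unlabeled: int,
                 tau: float = 0.95, warmup: bool = True):
        self.num_classes = num_classes
        self.tau = tau          # fixed upper threshold (assumed default)
        self.warmup = warmup    # whether to apply threshold warmup
        # One slot per sample of the *whole* unlabeled set, indexed by dataset index.
        self.selected_label: List[int] = [UNUSED] * num_unlabeled

    def update(self, sample_indices: Sequence[int],
               max_probs: Sequence[float], preds: Sequence[int]) -> None:
        """Store the predicted class of each sample whose confidence exceeds tau."""
        for idx, p, c in zip(sample_indices, max_probs, preds):
            if p > self.tau:
                self.selected_label[idx] = c

    def learning_effect(self) -> List[float]:
        """Normalized learning effect beta(c) for every class."""
        counts = Counter(self.selected_label)
        sigma = [counts.get(c, 0) for c in range(self.num_classes)]
        denom = max(sigma)
        if self.warmup:
            # Warmup: the number of still-unused samples also competes for the maximum.
            denom = max(denom, counts.get(UNUSED, 0))
        if denom == 0:
            return [0.0] * self.num_classes
        return [s / denom for s in sigma]

    def thresholds(self) -> List[float]:
        """Flexible threshold tau * M(beta(c)) with M(x) = x / (2 - x)."""
        return [self.tau * b / (2.0 - b) for b in self.learning_effect()]

    def mask(self, max_probs: Sequence[float], preds: Sequence[int]) -> List[bool]:
        """True where a pseudo-label passes the threshold of its predicted class."""
        t = self.thresholds()
        return [p >= t[c] for p, c in zip(max_probs, preds)]


if __name__ == "__main__":
    ft = FlexibleThresholds(num_classes=3, num_unlabeled=10)
    print(ft.thresholds())  # [0.0, 0.0, 0.0] before any confident prediction
    ft.update([0, 1, 2, 3, 7], [0.99, 0.97, 0.96, 0.5, 0.98], [0, 0, 1, 2, 0])
    print(ft.thresholds())
```

Before any prediction exceeds τ, every β(c) is 0, so all class thresholds start at 0 and every pseudo-label passes the mask; thresholds then rise as σ(c) grows. The sketch indexes selected_label by each sample's position in the whole unlabeled dataset rather than its position within a batch, so a sample's record persists across batches.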

Supported datasets

We support 5 popular datasets in SSL research as listed below:

  1. CIFAR-10
  2. CIFAR-100
  3. STL-10
  4. SVHN
  5. ImageNet

Installation

  1. Install conda (e.g., Anaconda or Miniconda)
  2. Run conda env create -f environment.yml

Usage

Running experiments with TorchSSL is straightforward. For example, to run the FlexMatch algorithm:

  1. Edit the config file config/flexmatch/flexmatch.yaml as needed.
  2. Run python flexmatch.py --c config/flexmatch/flexmatch.yaml
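To launch several runs in sequence, a small Python helper can build the same command. This is a hypothetical sketch: it assumes per-run config files follow a flexmatch_{dataset}_{num_labels}_{index}.yaml naming pattern (as in the flexmatch_cifar100_400_0.yaml file mentioned in user reports), and the helper names flexmatch_command and run_all are illustrative, not part of TorchSSL. By default it only prints the commands.

```python
import shlex
import subprocess
from typing import Iterable, List


def flexmatch_command(dataset: str, num_labels: int, index: int,
                      config_dir: str = "config/flexmatch") -> List[str]:
    """Build the launch command for one FlexMatch config (naming scheme assumed)."""
    config = f"{config_dir}/flexmatch_{dataset}_{num_labels}_{index}.yaml"
    return ["python", "flexmatch.py", "--c", config]


def run_all(dataset: str, num_labels: int, indices: Iterable[int],
            dry_run: bool = True) -> None:
    """Print (and optionally run) one command per config index, sequentially."""
    for i in indices:
        cmd = flexmatch_command(dataset, num_labels, i)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a run fails


if __name__ == "__main__":
    run_all("cifar100", 400, indices=[0, 1])
```

Set dry_run=False to actually start the runs one after another; with check=True, the loop stops at the first run that exits with an error.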

Customization

To implement your own algorithm, follow these steps:

  1. Create a directory for your algorithm, e.g., SSL, and write your model file SSL/SSL.py in it.
  2. Write the training script SSL.py (the counterpart of flexmatch.py).
  3. Write the config file config/SSL/SSL.yaml.

Results

Citation

If you find this toolkit or the results helpful in your research, please cite our paper:

@inproceedings{zhang2021flexmatch,
  title={FlexMatch: Boosting Semi-supervised Learning with Curriculum Pseudo Labeling},
  author={Zhang, Bowen and Wang, Yidong and Hou, Wenxin and Wu, Hao and Wang, Jindong and Okumura, Manabu and Shinozaki, Takahiro},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

Maintainer

Yidong Wang¹, Hao Wu², Bowen Zhang¹, Wenxin Hou¹,³, Jindong Wang³

¹ Shinozaki Lab: http://www.ts.ip.titech.ac.jp/

² Okumura Lab: http://lr-www.pi.titech.ac.jp/wp/

³ Microsoft Research Asia

References

[1] Antti Rasmus, Harri Valpola, Mikko Honkala, Mathias Berglund, and Tapani Raiko. Semi-supervised learning with ladder networks. In NeurIPS, pages 3546–3554, 2015.

[2] Antti Tarvainen and Harri Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In NeurIPS, pages 1195–1204, 2017.

[3] Dong-Hyun Lee et al. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Workshop on Challenges in Representation Learning, ICML, volume 3, 2013.

[4] Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE TPAMI, 41(8):1979–1993, 2018.

[5] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. MixMatch: A holistic approach to semi-supervised learning. In NeurIPS, pages 5050–5060, 2019.

[6] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. Unsupervised data augmentation for consistency training. In NeurIPS, volume 33, 2020.

[7] David Berthelot, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Kihyuk Sohn, Han Zhang, and Colin Raffel. ReMixMatch: Semi-supervised learning with distribution matching and augmentation anchoring. In ICLR, 2020.

[8] Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin A. Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li. FixMatch: Simplifying semi-supervised learning with consistency and confidence. In NeurIPS, volume 33, 2020.

[9] Bowen Zhang, Yidong Wang, Wenxin Hou, Hao Wu, Jindong Wang, Manabu Okumura, and Takahiro Shinozaki. FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling. In NeurIPS, 2021.

Comments
  • Ask for the code

    Hi, I am interested in your work; the research is very interesting. I want to know whether the FlexMatch code reproduces the results of the paper. I tried the CIFAR100_400 case, but the top-1 accuracy doesn't go up.

    I looked at the code and have a question about it. Is it right that classwise_acc starts at 0 for every class? Doesn't this mean that the threshold of each class is zero? This seems to break the model and lower the score. Is it the paper's intention to raise the threshold from zero? (Sorry if I misunderstood.)

    Is there any other way to train the model properly?

    Best Regards, Harim

    opened by harimkang 22
  • Release of training log or tensorboard for FlexMatch on CIFAR100

    Dear authors,

    Thanks for the impressive work FlexMatch and the awesome codebase!

    I am currently running some experiments on CIFAR100 based on this repo, and I found that FlexMatch training takes around 50 min per 5k iterations (on 3 × NVIDIA RTX 2080 Ti), so the whole training would take about 7 days, which is quite long.

    Would it be possible for you to release the training log or TensorBoard files for FlexMatch on CIFAR100 as a reference? It would help me a lot. Many thanks!

    Cheers, Haiming

    opened by HeimingX 10
  • tensorboard not working

    Hello, and many thanks for creating this repo!

    I've run several experiments and the log gives the expected results, but when I run TensorBoard I consistently get "No dashboards active". Am I doing something wrong?

    opened by nikoskaraliolios 8
  • Some questions about multi-gpu training

    First, thanks a lot for open-sourcing this codebase, which will help the development of the semi-supervised field. I have some questions about multi-GPU training and look forward to your reply. As far as I understand, this code supports multi-GPU training. Have you tested different parameters in a multi-GPU environment, for example, increasing the batch size when using more GPUs? And how did you set the parameters in the .yaml file when using multiple GPUs on one machine?

    opened by wanghao14 7
  • Reproduce Numbers on CIFAR100

    Hi,

    Thanks for the great work. I tried to reproduce FlexMatch number on CIFAR100 with 400 labels. I followed the instructions to create a conda env and ran python flexmatch.py --c config/flexmatch/flexmatch_cifar100_400_0.yaml with 3 Tesla V100.

    Unfortunately, the best top-1 accuracy I got was 60.65, which is 6.91% lower than the number reported in the paper. There seems to be a sharp performance drop at around iteration 600K, but I couldn't pinpoint the issue from the training statistics. Have you observed similar behavior? It would be great if you could offer some insights here. Thanks!

    Here is the tensorboard file: tf_logs.zip

    Besides, the curve of the mask ratio looks a bit strange to me. Since the code actually logs 1.0 - mask.detach(), shouldn't it start at 1 and then decrease? Any intuition why it starts at 0 and increases at the very beginning? Thanks!

    opened by YUE-FAN 6
  • Results in the paper

    Thanks for your great work!

    Just one quick question, how many GPUs were used to obtain the results in the paper? I didn't seem to find the specification on this.

    Best

    opened by ZhuoranYu 5
  • DataLoader worker (pid 12847) is killed by signal: Killed

    When I run python flexmatch.py --c ./config/flexmatch/flexmatch_cifar100_400_1.yaml, I always get the following error. Can you help me fix it?

    opened by ljjcoder 5
  • AttributeError

    I load a custom dataset with the same data loader as ImageNet and organized my data as "imagenet"/{train or val}/class name/*.jpg.

    But this error still appears. How should I solve it?

    My isic.yaml settings are attached: isic.txt

    opened by tangwwwwww 4
  • supervised to semi-supervised ?

    Hi, I just discovered your repo, it's great! However, I was wondering whether it is possible to modify, for example, an existing supervised training loop of mine to do semi-supervised learning. How would one adapt and use your repo for such a task? For example, if I am currently training with the detectron2 module, is there a way I can modify train.py to do semi-supervised learning?

    Thank you so much !

    opened by an99990 4
  • Validation set for CIFAR10

    Thanks for your great work!

    But it seems there is no validation set sampling part in this repository.

    As far as I know, 5,000 images are usually held out from the training set and used for validation (see, e.g., the validation set size options in the official FixMatch code, and A. Oliver et al., "Realistic Evaluation of Deep Semi-Supervised Learning Algorithms").

    I tried to find this in datasets/ssl_dataset.py, but I couldn't. Could you tell me where the validation split is created, or why there isn't one? Thanks!

    opened by Holim0711 4
  • Regarding the warmup in the snippet.

    Hi, Thanks a lot for your contribution.

    I wanted to ask about the warmup. In your FlexMatch training code, args.warmup is checked inside the condition if max(pseudo_counter.values()) < len(self.ulb_dset):

    Won't the code enter both this if block and the args.warmup block every time? I printed the condition, and throughout training it enters the args.warmup branch.

    According to the paper, shouldn't the warmup end after some time? Is there an issue in the code, or am I misunderstanding it?

    Thanks a lot, Shreejal

    opened by shreejalt 3
  • Fix the save_model() bug.

    I think there's a small bug when saving the model.

    The model is saved before self.it is incremented, so when training is resumed, it starts again from the same self.it. However, it should start from self.it + 1.

    https://github.com/TorchSSL/TorchSSL/blob/f2f46076cbea1b6f6c9b3c1c45609502c6576250/models/fixmatch/fixmatch.py#L191-L196

    https://github.com/TorchSSL/TorchSSL/blob/f2f46076cbea1b6f6c9b3c1c45609502c6576250/models/fixmatch/fixmatch.py#L220

    Note that with my workaround, save_model() can only be called after the model update and before self.it is incremented.

    opened by PM25 0
  • Drop in accuracy on resume

    I am running this code on a server that has a time limit of one day per job. So, I need to resume the code. I see a drop in accuracy when the training is resuming. Could you please comment on what might be causing such drops?

    opened by ShuvenduRoy 4
  • selected_label

    Hello, I read the code you open-sourced on GitHub and found that selected_label has the size of the entire unlabeled dataset. When selected_label is updated later, which index is used: the index within the current batch (e.g., with a batch size of 64, X_ULb_idx ranges over 0-63) or the index in the original dataset? I hope this description is clear.

    opened by lzw1997lzw 1
  • question about different version of code.

    Thanks for your excellent work. I found that the CIFAR-100 results differ between versions of the paper, and I noticed that your code also changed five months ago. I downloaded your code on 2021.10.19 and ran experiments with it. Is the code from 2021.10.19 different from the current version? Do I need to redo the experiments with the current code?

    opened by ljjcoder 1
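The resume off-by-one described in the save_model() report above can be reproduced with a toy loop. The sketch below is hypothetical and framework-free: a JSON file stands in for a model checkpoint, and save_checkpoint, load_checkpoint, and train are illustrative names, not TorchSSL functions. It stores the index of the next iteration to run (it + 1), so a resumed job does not re-apply an update that is already in the checkpoint.

```python
import json
import os
import tempfile


def save_checkpoint(path, state, next_it):
    # Record the iteration to run *next*, not the one that just finished.
    with open(path, "w") as f:
        json.dump({"state": state, "next_it": next_it}, f)


def load_checkpoint(path):
    with open(path) as f:
        ckpt = json.load(f)
    return ckpt["state"], ckpt["next_it"]


def train(path, total_iters, save_every, stop_at=None):
    """Toy training loop; state["updates"] lists every iteration whose update was applied."""
    if os.path.exists(path):
        state, it = load_checkpoint(path)
    else:
        state, it = {"updates": []}, 0
    while it < total_iters:
        if stop_at is not None and it == stop_at:
            return state  # simulate the job hitting its time limit
        state["updates"].append(it)  # stand-in for one optimizer step
        if it % save_every == 0:
            # The update for `it` is already applied, so resuming must start at it + 1.
            save_checkpoint(path, state, next_it=it + 1)
        it += 1
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = os.path.join(tmp, "ckpt.json")
        train(ckpt, total_iters=6, save_every=2, stop_at=4)  # first job is cut off
        final = train(ckpt, total_iters=6, save_every=2)     # second job resumes
        print(final["updates"])  # [0, 1, 2, 3, 4, 5]
```

Passing next_it=it instead reproduces the reported behavior: the resumed job would start at iteration 2 again and apply its update twice.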