A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper ---- NIPS 2017 poster ---- NIPS 2017 spotlight slides ---- Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
  3. Add an additional consistency cost between the student and teacher outputs (after softmax).
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate and applicable to large datasets.
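
The teacher update in step 5 can be sketched as follows. This is a minimal, dependency-free illustration, not the code from either implementation: weights are represented as plain lists of floats, and the names `update_teacher` and `decay` are chosen here for illustration only.

```python
def update_teacher(teacher, student, decay=0.999):
    """Move every teacher weight a small step toward the matching student weight.

    teacher, student: dicts mapping a parameter name to a list of floats
    (a stand-in for real parameter tensors; an assumption of this sketch).
    """
    for name, student_values in student.items():
        teacher_values = teacher[name]
        for i, s in enumerate(student_values):
            # Exponential moving average: keep `decay` of the old teacher
            # weight and mix in (1 - decay) of the current student weight.
            teacher_values[i] = decay * teacher_values[i] + (1.0 - decay) * s


# Step 1: the teacher starts as a copy of the student.
student = {"w": [1.0, 2.0], "b": [0.5]}
teacher = {name: list(values) for name, values in student.items()}

# Pretend the optimizer changed one student weight during a training step.
student["w"][0] = 2.0

# Step 5: after the training step, nudge the teacher toward the student.
# A small decay is used here only to make the effect visible.
update_teacher(teacher, student, decay=0.9)
print(teacher)
```

With a decay close to 1, the teacher changes slowly and averages the student over many training steps, which is what distinguishes it from sharing parameters directly.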

Mean Teacher model

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

ImageNet using 10% of the labels        top-5 validation error
Variational Auto-Encoder [paper]        35.42 ± 0.90
Mean Teacher ResNet-152                 9.11 ± 0.12
All labels, state of the art [paper]    3.79

CIFAR-10 using 4000 labels              test error
CT-GAN [paper]                          9.98 ± 0.21
Mean Teacher ResNet-26                  6.28 ± 0.15
All labels, state of the art [paper]    2.86

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.
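
One way to interleave labeled and unlabeled samples is to reserve a fixed number of slots in every minibatch for labeled examples. The sketch below illustrates that idea with plain index lists. It is not the sampler from the repository; the function name `mixed_batches` and its parameters are hypothetical.

```python
import random


def mixed_batches(labeled_idx, unlabeled_idx, batch_size, labeled_per_batch, seed=0):
    """Yield index batches that each contain a fixed number of labeled examples.

    One pass covers the unlabeled indices once (an incomplete final batch is
    dropped). Labeled indices are reshuffled and reused whenever they run
    out, since there are usually far fewer of them.
    """
    labeled_idx = list(labeled_idx)
    if not labeled_idx:
        raise ValueError("need at least one labeled index")
    if not 0 < labeled_per_batch < batch_size:
        raise ValueError("labeled_per_batch must be between 1 and batch_size - 1")

    rng = random.Random(seed)
    unlabeled = list(unlabeled_idx)
    rng.shuffle(unlabeled)
    unlabeled_per_batch = batch_size - labeled_per_batch

    labeled_pool = []
    last_start = len(unlabeled) - unlabeled_per_batch
    for start in range(0, last_start + 1, unlabeled_per_batch):
        while len(labeled_pool) < labeled_per_batch:
            refill = labeled_idx[:]
            rng.shuffle(refill)
            # Put the fresh indices in front so leftovers are used first.
            labeled_pool = refill + labeled_pool
        labeled_part = [labeled_pool.pop() for _ in range(labeled_per_batch)]
        yield labeled_part + unlabeled[start:start + unlabeled_per_batch]


# Toy data: indices 0..7 are labeled, 8..39 are unlabeled.
labeled = range(0, 8)
unlabeled = range(8, 40)
batches = list(mixed_batches(labeled, unlabeled, batch_size=8, labeled_per_batch=2))
for batch in batches:
    print(batch)
```

Placing the labeled examples at a fixed position in the batch makes it easy to slice out the part of the model output that the classification cost applies to.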

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Then, when you are happy with the architecture and hyperparameters, add Mean Teacher. The same network should work well, although you may want to tone down regularization, such as weight decay, that you used with small data.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch to labeled examples. Then the supervised training signal is strong enough early on to train quickly and to avoid getting stuck in uncertainty. In the PyTorch examples, a quarter or a half of each minibatch is reserved for labeled examples and the rest for unlabeled ones. (See TwoStreamBatchSampler in the PyTorch code.)
  • For the EMA decay rate, 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, it seems to be between the number of classes and the number of classes squared. On small datasets we saw MSE getting better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: have two separate logit layers at the top level. Use one for classification of labeled examples and one for predicting the teacher output, and then add an additional cost between the logits of these two predictions. The intent is the same as with the consistency cost ramp-up: in the beginning the teacher output may be wrong, so loosen the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
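
The two consistency cost functions and the ramp-up mentioned in the tips above can be sketched for a single example as follows. This is an illustration under stated assumptions, not the repository code: the MSE is averaged over classes, the KL-divergence is taken from the teacher distribution to the student distribution, and the ramp-up uses the shape exp(-5 (1 - t)^2). All three choices are assumptions of this sketch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def mse_consistency(student_logits, teacher_logits):
    """Mean squared difference between the two softmax outputs."""
    s = [math.exp(v) for v in log_softmax(student_logits)]
    t = [math.exp(v) for v in log_softmax(teacher_logits)]
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)


def kl_consistency(student_logits, teacher_logits):
    """KL(teacher || student) between the two softmax outputs."""
    log_s = log_softmax(student_logits)
    log_t = log_softmax(teacher_logits)
    return sum(math.exp(lt) * (lt - ls) for ls, lt in zip(log_s, log_t))


def sigmoid_rampup(epoch, rampup_epochs):
    """Factor that rises smoothly from about 0.007 to 1 over rampup_epochs."""
    if rampup_epochs == 0:
        return 1.0
    progress = min(max(epoch, 0.0), rampup_epochs) / rampup_epochs
    return math.exp(-5.0 * (1.0 - progress) ** 2)


student_logits = [2.0, 0.5, -1.0]
teacher_logits = [1.5, 1.0, -0.5]
max_weight = 10.0  # upper end of the range suggested for KL-divergence

for epoch in range(6):
    weight = max_weight * sigmoid_rampup(epoch, rampup_epochs=5)
    cost = weight * kl_consistency(student_logits, teacher_logits)
    print(epoch, round(weight, 3), round(cost, 4))
```

In training, the teacher output is treated as a constant target, so only the student receives gradients from this cost.
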
Comments
  • Adapting to Different Image Size

    Hello,

    Our team is interested in testing an implementation of the mean-teacher Resnet in PyTorch for a few image classification problems we are working on.

    However, we are having difficulty adapting the network to our image dimensions.

    If I resize our images to 32x32 it runs without error. But, if I change to something else, I get:

    Traceback (most recent call last):
      File "/opt/conda/lib/python3.5/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/opt/conda/lib/python3.5/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/experiments/rbc_test.py", line 76, in <module>
        run(**run_params)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/experiments/rbc_test.py", line 71, in run
        main.main(context)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/main.py", line 97, in main
        train(train_loader, model, ema_model, optimizer, epoch, training_log)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/main.py", line 225, in train
        ema_model_out = ema_model(ema_input_var)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 68, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 78, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
        raise output
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
        output = module(*input, **kwargs)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 158, in forward
        x = self.layer3(x)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/container.py", line 67, in forward
        input = module(input)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 255, in forward
        residual = self.downsample(x)
      File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
        result = self.forward(*input, **kwargs)
      File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 302, in forward
        x[:, :, 1::2, 1::2]), dim=1)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1512382878663/work/torch/lib/THC/generic/THCTensorMath.cu:157
    

    Which makes sense. We're just a little unfamiliar with PyTorch and, speaking for myself, Resnet. So, I thought I would post this question while I was looking into this to see if someone might post an obvious hint that may not be obvious to find.

    Thank you in advance, Tommy

    opened by tjdurant 8
  • The workaround for the issue of tensorflow(>=1.3)

    This is an issue with TensorFlow (>=1.3): https://github.com/tensorflow/tensorflow/issues/12598

    Workaround: use an arbitrary initializer for the variables, and then assign the value to the ref returned by the initialization pass.

    opened by LyleW 4
  • SVHN - final accuracy

    Hi, I ran your TensorFlow code (file train_svhn.py) and the final accuracy was only around 90%. I did not change anything in the code; I ran it as is. Do you have any suggestions as to why I do not get the expected 96%? By the way, I ran it on one GPU.

    opened by boussad83 4
  • AttributeError: 'DataFrame' object has no attribute 'to_msgpack'

    Traceback (most recent call last):
      File "main.py", line 423, in <module>
        main(RunContext(__file__, 0))
      File "main.py", line 104, in main
        train(train_loader, model, ema_model, optimizer, epoch, training_log)
      File "main.py", line 310, in train
        **meters.sums()
      File "/home/lbl/work/mean-teacher-master/pytorch/mean_teacher/run_context.py", line 34, in record
        self._record(step, col_val_dict)
      File "/home/lbl/work/mean-teacher-master/pytorch/mean_teacher/run_context.py", line 45, in _record
        self.save()
      File "/home/lbl/work/mean-teacher-master/pytorch/mean_teacher/run_context.py", line 38, in save
        df.to_msgpack(self.log_file_path, compress='zlib')
      File "/home/lbl/miniconda3/lib/python3.7/site-packages/pandas/core/generic.py", line 5274, in __getattr__
        return object.__getattribute__(self, name)
    AttributeError: 'DataFrame' object has no attribute 'to_msgpack'

    My pandas version is 1.0.1, and this function may have been removed in an earlier release. What can I do?

    opened by bolin12 3
  • License?

    Hi, we'd like to release some code which uses parts of this codebase. However, your code has no license listed. Do you mind adding a license, so that others can use your code? https://help.github.com/articles/licensing-a-repository/

    opened by craffel 3
  • Query regarding input transformation

    Hey, I guess input and ema_input are transformed versions of the same images, right? (https://github.com/CuriousAI/mean-teacher/blob/618c84430da22ef3fddc670894802cd4635c9dc2/pytorch/main.py#L208)

    If so, did you experiment with using the same input for both model and ema_model? Does using the same input lead to a drop in performance?

    Thanks !

    opened by Viresh-R 3
  • Finally, which one should I take, teacher or student?

    Hi, I am very impressed with your research.

    This may seem like a stupid question, but I'm wondering which model I should use for validation. I'm confused by some knowledge distillation (KD) methodologies, which also use the terms teacher and student model. At first I thought the terms had different meanings (i.e., the word teacher in Mean Teacher and in KD). However, I'm wondering why the EMA model (the teacher) performs better than the student model, which is trained with supervision from ground truths. (The slides also say that the teacher model leads the student model.) Indeed, the experimental results show that the teacher model performs better than the student.

    1. How should I understand the EMA-weighted model performing better than the student model?
    2. So, is it correct that the teacher model is the one used in the final system?

    Thanks for reading.

    opened by Lilac-wgk 2
  • How to unpack training.msgpack and show the training logs?

    Thanks for your inspiring idea and the corresponding code.

    I have run the CIFAR-10 experiments from your code on the AWS cloud. After I trained the network, the logs were saved in the cloud. I then downloaded the result files, such as training.msgpack, but I don't know how to unpack them to show the training logs.

    I have googled and searched Stack Overflow, but I still have not found a way to show the logs.

    Would you please show me how to unpack the .msgpack file and show the logs?

    Thanks.

    opened by zhe-meng 2
  • RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.cuda.LongTensor for argument #2 'other'

    The result of running python main.py ...... is shown below:

    Traceback (most recent call last):
      File "main.py", line 424, in <module>
        main(RunContext(__file__, 0))
      File "main.py", line 104, in main
        train(train_loader, model, ema_model, optimizer, epoch, training_log)
      File "main.py", line 274, in train
        meters.update('top1', prec1[0], labeled_minibatch_size)
      File "/home/gzx/Meanteacher/mean-teacher/pytorch/mean_teacher/utils.py", line 53, in update
        self.meters[name].update(value, n)
      File "/home/gzx/Meanteacher/mean-teacher/pytorch/mean_teacher/utils.py", line 86, in update
        self.sum = self.sum +(val * n)
    RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.cuda.LongTensor for argument #2 'other'

    Is this because I installed PyTorch using conda? The PyTorch version is 0.4.1.

    opened by zhangxgu 2
  • ImageNet instructions

    Hi, thanks for the great work! I am wondering whether you could upload instructions for reproducing the ImageNet results. It seems the data preparation for ImageNet is missing. Thanks.

    opened by xinmei9322 2
  • Keep Training but no output

    I used the suggested command:

    python main.py \
        --dataset cifar10 \
        --labels data-local/labels/cifar10/1000_balanced_labels/00.txt \
        --arch cifar_shakeshake26 \
        --consistency 100.0 \
        --consistency-rampup 5 \
        --labeled-batch-size 62 \
        --epochs 180 \
        --lr-rampdown-epochs 210

    I'm using Ubuntu 18.04, Python 3.6, PyTorch 0.3.0, NumPy 1.14.2, and CUDA 8.0, with two GTX 1080 Ti GPUs. When I run main.py, training starts, but there is no output (epochs, accuracy) during the training process, even after a few hours. The only output looks like this:

    INFO:main:=> creating model 'cifar_shakeshake26'
    INFO:main:=> creating EMA model 'cifar_shakeshake26'
    INFO:main: List of model parameters:

    module.conv1.weight 16 * 3 * 3 * 3 = 432
    module.layer1.0.conv_a1.weight 96 * 16 * 3 * 3 = 13,824
    module.layer1.0.bn_a1.weight 96 = 96
    module.layer1.0.bn_a1.bias 96 = 96
    .....
    module.fc2.weight 10 * 384 = 3,840
    module.fc2.bias 10 = 10

    all parameters sum of above = 26,197,316

    I have checked the results folder and there is no checkpoint file in it.

    opened by Lyt859165290 1
  • About two different loss while in training progress

    Hi there, I downloaded this code and adapted it for my semi-supervised segmentation task (PyTorch version). Thanks for the great code you provided!

    My question is this: as I understand it, the Mean Teacher model has two losses. One is the supervised loss between predictions on labeled data and the ground truth, and the other is the contrast loss.

    Here is how I calculated the contrast loss:

    1. get the consistency weight as 10 * sigmoid_rampup(epoch, 5) at each epoch
    2. compute the logits from the student's and the teacher's outputs

    The contrast loss then goes up into the thousands, which does not seem right. Is this normal, or is there a bug in my code? Could you give me some advice if you have any ideas? Thanks!

    opened by DISAPPEARED13 1
  • How to train the model with unlabeled data?

    I want to transfer the MT framework to an NLP task, but I don't understand how to train it with unlabeled data. I understand the idea of the paper, but I'm confused about the implementation.

        if isinstance(model_out, Variable):
            assert args.logit_distance_cost < 0
            logit1 = model_out
            ema_logit = ema_model_out
        else:
            assert len(model_out) == 2
            assert len(ema_model_out) == 2
            logit1, logit2 = model_out
            ema_logit, _ = ema_model_out
    
        ema_logit = Variable(ema_logit.detach().data, requires_grad=False)
    
        if args.logit_distance_cost >= 0:
            class_logit, cons_logit = logit1, logit2
            res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
            meters.update('res_loss', res_loss.data[0])
        else:
            class_logit, cons_logit = logit1, logit1
            res_loss = 0
    
        class_loss = class_criterion(class_logit, target_var) / minibatch_size
        meters.update('class_loss', class_loss.data[0])
    
        ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
        meters.update('ema_class_loss', ema_class_loss.data[0])
    
        if args.consistency:
            consistency_weight = get_current_consistency_weight(epoch)
            meters.update('cons_weight', consistency_weight)
            consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
            meters.update('cons_loss', consistency_loss.data[0])
        else:
            consistency_loss = 0
            meters.update('cons_loss', 0)
    

    I notice that the TwoStreamBatchSampler divides the dataset into a labeled part and an unlabeled part, but the code above seems to handle both labeled and unlabeled data in the same way. I think only the labeled part of model_out should be used to calculate the class_loss. Did I get it wrong?

    opened by broken-dream 1
  • The Cifar10 dataset link in Readme in meanteacher/pytorch is lost

    Hi, the link in the dataset preparation command of the PyTorch version of mean-teacher, pip install git+ssh://git@github.com/pytorch/vision@c31c3d7e0e68e871d2128c8b731698ed3b11b119, no longer exists. I hope you can update the dataset link. Thank you!

    opened by TeleRagingFires 0
  • Applying approach for NLP problem

    I am planning to apply Mean Teacher to my token classification problem. Since adding different noise for the teacher and the student is really important for the approach, I am confused about how to calculate the consistency cost, as the length of the active logits would differ. For example, if I use synonym noise, the sentence can get longer (some tokens may be replaced by a two-token synonym) when given to the teacher model, and the same augmentation may generate a different sentence (of a different length) when given to the student model.

    opened by tarunbhatiaind 0
Owner
Curious AI
Deep good. Unsupervised better.