2.86% and 15.85% test error on CIFAR-10 and CIFAR-100

Overview

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. The arXiv paper is an extension of Shake-Shake regularization of 3-branch residual networks, which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims to help deep learning practitioners faced with overfitting. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best single-shot published results on CIFAR-10 and CIFAR-100, reaching test errors of 2.86% and 15.85%.

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.
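To make Figure 1 concrete, here is a framework-free sketch of one 3-branch residual block. It assumes, as in the paper, that during training the forward coefficient alpha and the backward coefficient beta are drawn independently and uniformly from [0, 1], and that alpha is fixed to 0.5 at test time. Plain Python lists stand in for tensors, and the function names are hypothetical. The actual Torch implementation is in models/shakeshakeblock.lua.

```python
import random

def shake_forward(x, branch1, branch2, alpha):
    """Forward pass: x + alpha * F1(x) + (1 - alpha) * F2(x), elementwise."""
    return [xi + alpha * f1 + (1.0 - alpha) * f2
            for xi, f1, f2 in zip(x, branch1, branch2)]

def shake_backward(grad_out, beta):
    """Backward pass into the two residual branches.

    Branch 1 receives beta * grad and branch 2 receives (1 - beta) * grad,
    where beta is a fresh random number independent of the forward alpha.
    The identity (skip) path receives grad_out unchanged and is not shown.
    """
    grad1 = [beta * g for g in grad_out]
    grad2 = [(1.0 - beta) * g for g in grad_out]
    return grad1, grad2

def residual_block(x, branch1, branch2, training, rng=random):
    """Training: alpha ~ U(0, 1). Test time: alpha = 0.5 (its expected value)."""
    alpha = rng.random() if training else 0.5
    return shake_forward(x, branch1, branch2, alpha)

# Test-time pass: 1.0 + 0.5 * 2.0 + 0.5 * 4.0
print(residual_block([1.0], [2.0], [4.0], training=False))  # [4.0]
```

Fixing alpha to 0.5 at test time replaces the random combination by its expectation. This is why the right panel of Figure 1 is deterministic.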

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches, and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that, for the backward pass, we keep the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, the same scaling coefficient is applied to all the images in the mini-batch. "Image" means that, for each residual block, a different scaling coefficient is applied to each image in the mini-batch. The numbers in the table below are the average of 3 runs, except for the 96d models, which were run 5 times.

Forward   Backward  Level   26 2x32d  26 2x64d  26 2x96d  26 2x112d
Even      Even      n/a     4.27      3.76      3.58      -
Even      Shake     Batch   4.44      -         -         -
Shake     Keep      Batch   4.11      -         -         -
Shake     Even      Batch   3.47      3.30      -         -
Shake     Shake     Batch   3.67      3.07      -         -
Even      Shake     Image   4.11      -         -         -
Shake     Keep      Image   4.09      -         -         -
Shake     Even      Image   3.47      3.20      -         -
Shake     Shake     Image   3.55      2.98      2.86      2.82 (1)

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)
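The Forward, Backward and Level columns of Table 1 can be read as a rule for producing the scaling coefficients of one residual block. The plain-Python sketch below is one possible encoding of the definitions given above. The function, its parameters and the string labels are illustrative assumptions, not the repository's command-line options.

```python
import random

def draw_coefficients(mode, level, batch_size, rng=random, previous=None):
    """Return one scaling coefficient per image in the mini-batch.

    mode:  "shake" -> fresh random numbers in [0, 1)
           "even"  -> every coefficient set to 0.5 (level is ignored)
           "keep"  -> reuse the coefficients drawn for the forward pass
    level: "batch" -> one coefficient shared by all images of the block
           "image" -> an independent coefficient for each image
    """
    if mode == "even":
        return [0.5] * batch_size
    if mode == "keep":
        if previous is None:
            raise ValueError("'keep' needs the forward-pass coefficients")
        return list(previous)
    if mode != "shake":
        raise ValueError(f"unknown mode: {mode}")
    if level == "batch":
        return [rng.random()] * batch_size
    if level == "image":
        return [rng.random() for _ in range(batch_size)]
    raise ValueError(f"unknown level: {level}")

# "Shake-Keep-Image": forward coefficients drawn per image, reused backward.
rng = random.Random(0)
alphas = draw_coefficients("shake", "image", 4, rng)
betas = draw_coefficients("keep", "image", 4, previous=alphas)
assert betas == alphas
```

Under this reading, "Shake-Keep" behaves like a plain stochastic affine combination with consistent gradients. "Shake-Shake" and "Shake-Even" deliberately decouple the backward coefficients from the forward ones.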

Other results

CIFAR-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05% (1)

SVHN:
26 2x96d: 1.4% (1)

Reduced SVHN:
26 2x96d: 12.32% (1)

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake:
git clone https://github.com/xgastaldi/shake-shake.git
  3. Copy the contents of the shake-shake folder into the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  4. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, halve both the batch size (from 128 to 64) and the learning rate (from 0.2 to 0.1):

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true
  5. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true
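The 1-GPU CIFAR-10 command above halves the batch size and the learning rate together. This matches the common linear-scaling heuristic, which keeps the ratio of learning rate to batch size constant. The small helper below is hypothetical, and the heuristic is our reading of the two commands rather than a rule stated by the author.

```python
def scaled_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Linear scaling heuristic: keep lr / batch_size constant."""
    return base_lr * new_batch_size / base_batch_size

# 2 GPUs: -batchSize 128 -LR 0.2  ->  1 GPU: -batchSize 64
print(scaled_learning_rate(0.2, 128, 64))  # 0.1
```

The heuristic is only a starting point. If you change the batch size by a large factor, the learning rate may need retuning.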

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds logging

train.lua
Ln 36-38, 58-60, 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'
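The train.lua changes add a cosine learning rate function (the runs above use -lrShape cosine). Below is a minimal sketch of cosine annealing without restarts. It assumes the common form lr = 0.5 * LR * (1 + cos(pi * t / T)), where t is the current step and T the total number of steps. Whether train.lua counts t in epochs or in iterations is not stated here, so check the Lua source before relying on the exact schedule. The helper name is hypothetical.

```python
import math

def cosine_lr(base_lr, step, total_steps):
    """Cosine annealing without restarts: base_lr at step 0, 0 at total_steps."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * step / total_steps))

# With -LR 0.2 and -nEpochs 1800, counting steps in epochs:
for epoch in (0, 450, 900, 1350, 1800):
    print(epoch, round(cosine_lr(0.2, epoch, 1800), 4))
```

The learning rate falls slowly at first, reaches half of -LR at the midpoint and approaches zero at the end of training. That is one reason the README reports the error of the last epoch.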

The main model is in shakeshake.lua and the residual block model is in shakeshakeblock.lua. mulconstantslices.lua is a simple extension of nn.MulConstant that multiplies each image slice of a mini-batch tensor by the corresponding element of a vector. shakeshaketable.lua contains the method used for CIFAR-100, since the ResNeXt code uses a table implementation instead of a module.
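The per-image multiplication that mulconstantslices.lua performs (used for the "Image" level) can be sketched in plain Python with nested lists shaped [N][C][H][W]. This only mirrors the one-line description above. The function name is hypothetical, and the Lua module's actual interface is not reproduced.

```python
def mul_constant_slices(batch, coefficients):
    """Multiply every element of image n in the mini-batch by coefficients[n].

    batch is a nested list shaped [N][C][H][W]; coefficients has length N.
    """
    if len(batch) != len(coefficients):
        raise ValueError("need exactly one coefficient per image")
    return [
        [[[c * v for v in row] for row in channel] for channel in image]
        for image, c in zip(batch, coefficients)
    ]

# Two 1x1x2 images scaled by 0.5 and 2.0 respectively.
print(mul_constant_slices([[[[1.0, 2.0]]], [[[3.0, 4.0]]]], [0.5, 2.0]))
# [[[[0.5, 1.0]]], [[[6.0, 8.0]]]]
```

With the "Batch" level, every coefficient in the vector would be equal, which reduces to an ordinary nn.MulConstant.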

Reimplementations

PyTorch
https://github.com/hysts/pytorch_shake_shake

TensorFlow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.

Comments
  • Question about the number of epochs needed for convergence

    With the hyperparameter setting "22632", training converges at about epoch 1850.

    However, after several training runs with other hyperparameters, training with the same setting "22632" only converges after more than 3500 epochs.

    The two experiments use the same hyperparameters and the same code. The only difference is that the second experiment was run after the model had already been trained several times.

    opened by xuyifeng-nwpu 6
  • Question regarding Pooling

    I have a small question (I do not know Torch).

    In this code, when you build the skip connection for the blocks that decrease the resolution:

     -- Skip path #1
     s1 = nn.Sequential()
     s1:add(nn.SpatialAveragePooling(1, 1, stride, stride))
     s1:add(Convolution(nInputPlane, nOutputPlane/2, 1,1, 1,1, 0,0))
    
     -- Skip path #2
     s2 = nn.Sequential()
     -- Shift the tensor by one pixel right and one pixel down (to make the 2nd path "see" different pixels)
     s2:add(nn.SpatialZeroPadding(1, -1, 1, -1))
     s2:add(nn.SpatialAveragePooling(1, 1, stride, stride))
    

    Skip path #1 will take the 'top left' pixel in each 2x2 square, if I understand it correctly.

    Skip path #2 will take the 'bottom right' pixel in each 2x2 square. Here is the point I do not understand: if this code first adds zeros to the top and left of the feature map, then many pixels will have the value '0' after this downsampling. It would be better to add zeros to the right and bottom and remove the first row and first column. So instead of s2:add(nn.SpatialZeroPadding(1, -1, 1, -1)), use s2:add(nn.SpatialZeroPadding(-1, 1, -1, 1)).

    Tell me if it works as I described; I might be wrong because, as I said, I don't know Torch too well.

    opened by PatrykChrabaszcz 6
  • Why is the test top1 different?

    I changed shakeshakeblock.lua and then ran the code for 400 epochs. The training log was attached as a screenshot.

    After 400 epochs of training, I ran CUDA_VISIBLE_DEVICES=0,1,2,3 th main.lua -dataset cifar10 -nGPU 4 -testOnly true -retrain ./checkpoints/model_best.t7 and got: Results top1: 3.670 top5: 0.020

    Question 1: Why is the test top-1 error from "testOnly" lower than the test top-1 error reported during training?
    Question 2: What is the difference between the best test top-1, the last epoch's test top-1 and the "testOnly" top-1?
    Question 3: Since the top-1 of 3.67 was produced by the saved model, can I consider that my model achieves a top-1 error of 3.67?

    opened by xuyifeng-nwpu 4
  • CIFAR 100 training too slow

    Hi, I am trying to reproduce the CIFAR-100 result.

    I am using the command from the README:

    CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true
    

    I checked that the top-1 error is still 99% at epoch 111, and I'm afraid that it will not converge.

    I am running an up-to-date Torch built from source, with the necessary files copied from fb.resnet.torch, on Ubuntu 14.04 / CUDA 8 / cuDNN v4.

    Should I wait patiently, or are there any extra tricks for faster convergence?

    opened by Jongchan 3
  • CUDA out of memory when checkpointing

    Hi,

    When running the code, the model trains for 1 epoch without running out of memory, but when checkpointing for the first time it tries to make the copy on the GPU (not in CPU mode). I guess there is a reason for changing the default checkpointing of fb.resnet.torch. Can you please explain the reason?

    Thanks

    opened by arunpatala 3
  • Two approaches for improvement

    I think there are two approaches to improve the accuracy. Are these methods feasible?

    The first method: split a validation set from the training set, then adjust the learning rate automatically according to the validation accuracy.

    The second method: during training, once the test top-1 error drops below a fixed value, set the learning rate to zero for all subsequent epochs.

    For example, when the code was run with nEpochs 400, the log file was as follows:

        epoch   test top 1   learning rate
        370     3.62         0.02
        ...     ...          ...
        400     3.82         0.00

    In the log file, the best test top-1 error was 3.62, reached at epoch 370. If the learning rate for epochs 371 to 400 were set to zero, should the test top-1 error of all those epochs stay at 3.62? I tried this method and found that the test top-1 error after epoch 371 still changed slightly.

    Can you give me some suggestions about the two methods above? Have you compared adaptive learning-rate methods such as RMSprop or Adadelta with SGD? Thank you very much!

    opened by xuyifeng-nwpu 1
  • Error rates

    In README.md, you wrote: "Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)".

    However, the best error rate may not be reached at the last epoch. For example, with the parameter "32296", the top-1 errors of the last epoch and of the sixth-to-last epoch were 2.86 and 2.79 respectively.

    Can I consider the best top-1 error to be 2.79? How should I select the best top-1 error?

    opened by xuyifeng-nwpu 1
  • Could you give me some suggestions about possible ways to improve test accuracy?

    Thank you for sharing your code! Your result is the best one I have seen. I am now trying to improve the test error based on your code. Could you give me some suggestions about possible ways to improve test accuracy?

    You can contact me at my private email.

    Thanks.

    opened by xuyifeng-nwpu 1
  • The third column of results lists the wrong network architecture

    The columns are listed as: Forward Backward Level 26 2x32d 26 2x64d 26 2x32d in the README but are Forward Backward Level 26 2x32d 26 2x64d 26 2x96d in the paper.

    opened by Islandman93 1
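The pooling question above turns on what nn.SpatialZeroPadding(1, -1, 1, -1) does before the stride-2 subsampling. The pure-Python sketch below reproduces both variants on a 4x4 map. It assumes Torch's argument order (padLeft, padRight, padTop, padBottom), that negative amounts crop, and that a 1x1 average pooling with stride 2 simply keeps every second row and column. The helper names are hypothetical. The sketch only shows which pixels each skip path samples; it does not settle which variant trains better.

```python
def zero_pad_shift(img, pad_left, pad_right, pad_top, pad_bottom):
    """2-D zero padding where negative amounts crop instead of pad.

    img is a list of rows. Padding one side by 1 and cropping the
    opposite side by 1 shifts the image by one pixel, keeping its size.
    """
    h, w = len(img), len(img[0])
    out_h, out_w = h + pad_top + pad_bottom, w + pad_left + pad_right
    out = [[0] * out_w for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            src_i, src_j = i - pad_top, j - pad_left
            if 0 <= src_i < h and 0 <= src_j < w:
                out[i][j] = img[src_i][src_j]
    return out

def subsample(img, stride=2):
    """1x1 average pooling with stride 2 keeps every second row and column."""
    return [row[::stride] for row in img[::stride]]

img = [[4 * r + c + 1 for c in range(4)] for r in range(4)]  # pixels 1..16
print(subsample(img))                                # [[1, 3], [9, 11]]
print(subsample(zero_pad_shift(img, 1, -1, 1, -1)))  # [[0, 0], [0, 6]]
print(subsample(zero_pad_shift(img, -1, 1, -1, 1)))  # [[6, 8], [14, 16]]
```

On this example, skip path #1 samples the top-left pixel of each 2x2 square. The padding as written leaves a zero row and a zero column in the downsampled map. The suggested variant samples the bottom-right pixel of every 2x2 square.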