Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Overview

Train longer, generalize better - Big batch training

This is the code repository used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks" by Elad Hoffer, Itay Hubara and Daniel Soudry.

It is based off convNet.pytorch with some helpful options such as:

  • Training on several datasets
  • Complete logging of trained experiment
  • Graph visualization of the training/validation loss and accuracy
  • Definition of preprocessing and optimization regime for each model

Dependencies

Data

  • Configure your dataset path in data.py.
  • To get the ILSVRC data, you should register on their site for access: http://www.image-net.org/

Experiment examples

python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix;
python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 2048 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256;
python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128;
  • See run_experiments.sh for more examples
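
The large-batch corrections described in the paper can be sketched as follows. The paper scales the learning rate by the square root of the batch-size ratio and, for regime adaptation, stretches the schedule by the batch-size ratio so that the number of weight updates matches the small-batch run. The function names below and the small-batch reference size of 128 are illustrative assumptions; how --lr_bb_fix and --regime_bb_fix implement this in the scripts is not documented here, so treat this as a sketch of the idea rather than the repository's code.

```python
import math


def scaled_lr(base_lr, batch_size, base_batch_size=128):
    """Square-root scaling: lr * sqrt(large batch / small batch).

    base_batch_size=128 is an assumed small-batch reference.
    """
    return base_lr * math.sqrt(batch_size / base_batch_size)


def adapted_epochs(base_epochs, batch_size, base_batch_size=128):
    """Regime adaptation: multiply the epoch count by the batch-size ratio
    so the number of weight updates matches the small-batch run."""
    return base_epochs * math.ceil(batch_size / base_batch_size)


if __name__ == "__main__":
    print(scaled_lr(0.1, 2048))       # learning rate for a batch of 2048
    print(adapted_epochs(100, 2048))  # epochs after regime adaptation
```

With a batch of 2048 and a reference batch of 128 the ratio is 16, so the learning rate is multiplied by 4 and the schedule by 16.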

Model configuration

A network model is defined by writing a .py file in the models folder and selecting it with the model flag. The model function must be registered in models/__init__.py and must return a trainable network. It can also specify additional training options, such as an optimization regime (either a dictionary or a function) and input transform modifications.

For example, a model definition:

import torch.nn as nn
import torchvision.transforms as transforms


class Model(nn.Module):

    def __init__(self, num_classes=1000):
        super(Model, self).__init__()
        self.model = nn.Sequential(...)

        # Optimization regime, keyed by epoch
        self.regime = {
            0: {'optimizer': 'SGD', 'lr': 1e-2,
                'weight_decay': 5e-4, 'momentum': 0.9},
            15: {'lr': 1e-3, 'weight_decay': 0}
        }

        # Input preprocessing for training and evaluation
        self.input_transform = {
            'train': transforms.Compose([...]),
            'eval': transforms.Compose([...])
        }

    def forward(self, inputs):
        return self.model(inputs)


# Model function to register in models/__init__.py
def model(**kwargs):
    return Model()
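
One way to read an epoch-keyed regime like the one above is sketched below. It assumes that entries are applied in epoch order, that later entries override earlier ones, and that unspecified keys carry over. The helper name settings_at_epoch is hypothetical and is not taken from the repository.

```python
def settings_at_epoch(regime, epoch):
    """Merge every regime entry whose start epoch is <= `epoch`, earliest first.

    Assumption: later entries override earlier ones; other keys carry over.
    """
    settings = {}
    for start in sorted(regime):
        if start > epoch:
            break
        settings.update(regime[start])
    return settings


# Same regime as in the model definition above
regime = {
    0: {'optimizer': 'SGD', 'lr': 1e-2,
        'weight_decay': 5e-4, 'momentum': 0.9},
    15: {'lr': 1e-3, 'weight_decay': 0},
}

if __name__ == "__main__":
    print(settings_at_epoch(regime, 3))   # settings from the epoch-0 entry only
    print(settings_at_epoch(regime, 20))  # epoch-15 entry overrides lr and weight_decay
```

Under these assumptions, epoch 20 keeps the SGD optimizer and momentum from the first entry while using the lower learning rate and zero weight decay from the second.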
Comments
  • Problems with reproducing the experiments

    Hello,

    I'm having trouble reproducing the experiments from your paper using code from this repository. Let's focus on Table 1 from the paper, column +GBN and experiments C1, Resnet44, C3 (although I have problems with more results, including baselines). Here are the commands I use to run these 3 experiments:

    python main_gbn.py --dataset cifar10 --model cifar10_shallow --save c10_shallow_11 --epochs 200 --b 4096 --lr_bb_fix --mini-batch-size 128
    python main_gbn.py --dataset cifar10 --model resnet --save resnet_11 --epochs 200 --b 4096 --lr_bb_fix --mini-batch-size 128
    python main_gbn.py --dataset cifar100 --model cifar100_shallow --save c100_shallow_11 --epochs 200 --b 4096 --lr_bb_fix --mini-batch-size 128

    So I set batch size to 4096 and ghost batch size to 128 as instructed by the paper; the rest of the hyperparameters (number of epochs, learning rate schedule, momentum, gradient clipping constant, weight decay) remain as set in the code.

    Here are the results that I get, compared to results from the paper:

      • C1 LB + LR + GBN, last epoch: 75.07 +/- 0.10; best epoch: 75.41 +/- 0.11; in paper: 86.4
      • Resnet44 LB + LR + GBN, last epoch: 85.21 +/- 0.81; best epoch: 85.63 +/- 0.76; in paper: 90.50
      • C3 LB + LR + GBN, last epoch: 27.33 +/- 0.11; best epoch: 27.63 +/- 0.11; in paper: 57.5

    While in the last case training for more epochs would improve the results, in the second case it pretty much flattens out. I suspect that learning rate schedules are the main culprits, as by manipulating them I was able to enhance the results by far.

    Can you please help me reproduce the experiments and publish the hyperparameters from your experiments? Also, it would be helpful if you published which versions of the Python packages you had used in the original experiments. I use PyTorch 0.3.1 as recommended by the Smoothout paper repo (https://github.com/wenwei202/smoothout), which is forked from your repository, I believe.

    Thanks!

    opened by zajaczajac 0
  • what's the difference between GBN & BN used in framework?

    I've read your paper, but I don't understand the difference between GBN and the BN used in frameworks. In my understanding, GBN does BN with local data. Distributed frameworks also only do BN with local data. Can you explain the difference, please?

    opened by yljylj 2
Owner
Elad Hoffer