A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization

Overview

sam.pytorch

A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization (Foret+ 2020, arXiv:2010.01412), independent of the paper's official implementation.

Requirements

  • Python>=3.8
  • PyTorch>=1.7.1

To run the example, you also need:

  • homura: pip install -U homura-core==2020.12.0
  • chika: pip install -U chika

Example

python cifar10.py [--optim.name {sam,sgd}] [--model {resnet20,wrn28_2}] [--optim.rho 0.05]

Results: Test Accuracy (CIFAR-10)

Model       SAM (%)   SGD (%)
ResNet-20   93.5      93.2
WRN28-2     95.8      95.4
ResNeXT29   96.4      95.8

SAM needs two forward-backward passes per update, so training with SAM is slower than training with SGD. For ResNet-20, training took about 80 minutes with SAM versus 50 minutes with SGD in my environment. The additional options --use_amp and --jit_model may slightly accelerate training.
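The two passes follow from the SAM update rule. As a sketch in the paper's L2-norm setting, with loss L, weights w, neighborhood size ρ (presumably what --optim.rho sets) and learning rate η (a symbol assumed here for illustration):

```latex
\begin{aligned}
\hat{\epsilon}(w) &= \rho \, \frac{\nabla_w L(w)}{\lVert \nabla_w L(w) \rVert_2}
  && \text{(pass 1: gradient at } w \text{)} \\
w &\leftarrow w - \eta \, \nabla_w L(w) \big|_{w + \hat{\epsilon}(w)}
  && \text{(pass 2: gradient at } w + \hat{\epsilon}(w) \text{)}
\end{aligned}
```

Each gradient costs one forward and one backward pass, so a SAM update does roughly twice the work of an SGD update. The measured slowdown (80 vs 50 minutes) is smaller than 2x, presumably because not all training time is spent in these passes.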

Usage

SAMSGD can be used as a drop-in replacement for PyTorch optimizers that take a closure. It is also compatible with lr_scheduler and provides state_dict and load_state_dict.

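from sam import SAMSGD

# `model`, `loss_f`, and `dataset` are assumed to be defined elsewhere.
optimizer = SAMSGD(model.parameters(), lr=1e-1, rho=0.05)

for input, target in dataset:
    # SAM needs two forward-backward passes per update, so the whole
    # computation is wrapped in a closure that the optimizer can re-run.
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_f(output, target)
        loss.backward()
        return loss

    loss = optimizer.step(closure)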
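
To illustrate why the closure is required and how such an optimizer can interact with lr_scheduler and state_dict, here is a minimal sketch. MinimalSAMSGD is a hypothetical stand-in written for this example, not the repository's SAMSGD, and the toy linear-regression data, StepLR schedule, and step counts are assumptions chosen only to keep it self-contained.

```python
import torch
from torch import nn


class MinimalSAMSGD(torch.optim.SGD):
    """Hypothetical minimal SAM-on-SGD sketch; NOT the repository's SAMSGD."""

    def __init__(self, params, lr, rho=0.05, **kwargs):
        super().__init__(params, lr=lr, **kwargs)
        self.rho = rho  # neighborhood size (kept out of state_dict in this sketch)

    @torch.no_grad()
    def step(self, closure):
        # The closure calls backward(), so gradients must be enabled inside it.
        closure = torch.enable_grad()(closure)
        loss = closure()  # pass 1: gradient at the current weights
        params = [p for group in self.param_groups
                  for p in group["params"] if p.grad is not None]
        grad_norm = torch.stack([p.grad.norm(2) for p in params]).norm(2)
        perturbations = []
        for p in params:
            eps = p.grad * (self.rho / (grad_norm + 1e-12))
            p.add_(eps)  # move to w + eps
            perturbations.append(eps)
        closure()  # pass 2: gradient at the perturbed weights
        for p, eps in zip(params, perturbations):
            p.sub_(eps)  # restore w
        super().step()  # plain SGD update using the pass-2 gradient
        return loss


torch.manual_seed(0)
model = nn.Linear(1, 1)
loss_f = nn.MSELoss()
inputs = torch.linspace(-1, 1, 16).unsqueeze(1)
targets = 2 * inputs + 1  # toy regression target (assumption)

optimizer = MinimalSAMSGD(model.parameters(), lr=1e-1, rho=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

losses = []
for _ in range(100):
    def closure():
        optimizer.zero_grad()
        loss = loss_f(model(inputs), targets)
        loss.backward()
        return loss

    losses.append(optimizer.step(closure).item())
    scheduler.step()

# Checkpoint round trip through state_dict / load_state_dict.
checkpoint = optimizer.state_dict()
restored = MinimalSAMSGD(model.parameters(), lr=1e-1, rho=0.05)
restored.load_state_dict(checkpoint)
```

Subclassing torch.optim.SGD keeps the scheduler and checkpoint behavior of the base optimizer, so only the perturb, re-evaluate, restore logic is new. With the repository's optimizer, the same loop would use SAMSGD in place of MinimalSAMSGD.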

Citation

@ARTICLE{2020arXiv201001412F,
    author = {{Foret}, Pierre and {Kleiner}, Ariel and {Mobahi}, Hossein and {Neyshabur}, Behnam},
    title = "{Sharpness-Aware Minimization for Efficiently Improving Generalization}",
    year = 2020,
    eid = {arXiv:2010.01412},
    eprint = {2010.01412},
}

@software{sampytorch,
    author = {Ryuichiro Hataya},
    title  = {sam.pytorch},
    url    = {https://github.com/moskomule/sam.pytorch},
    year   = {2020}
}
Comments
  • Reproducing WRN-28-10 (SAM) for SVHN dataset

    I am trying to reproduce the results for WRN-28-10 (SAM) trained on 10-class classification SVHN dataset (Percentage Error 0.99) - https://paperswithcode.com/sota/image-classification-on-svhn

    I'm able to train WRN-28-10 using https://github.com/hysts/pytorch_wrn (Modified the script to incorporate SAM into it)

    I'm achieving a test accuracy of 93%. How can I replicate the SOTA percentage error of 0.99 for WRN-28-10 (SAM)? Which hyperparameters should I use?

    Any help is appreciated!!

    opened by akshaydp1995
Owner
Ryuichiro Hataya
PhD student at UTokyo and RA at RIKEN AIP / focusing on DL and ML