
RAdam

On the Variance of the Adaptive Learning Rate and Beyond

We are in an early-release beta. Expect some adventures and rough edges.

Introduction

If warmup is the answer, what is the question?

Learning rate warmup for Adam is a must-have trick for stable training in certain situations (the alternative being eps tuning), but the underlying mechanism is largely unknown. In our study, we suggest that one fundamental cause is the large variance of the adaptive learning rate, and we provide both theoretical and empirical evidence to support this.

In addition to explaining why we should use warmup, we also propose RAdam, a theoretically sound variant of Adam.

Motivation

As shown in Figure 1, we assume that gradients follow a normal distribution (mean: \mu, variance: 1). The variance of the adaptive learning rate is simulated and plotted in Figure 1 (blue curve). We observe that the adaptive learning rate has a large variance in the early stage of training.
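
A minimal sketch of such a simulation (an illustration only, not the exact script behind Figure 1; the N(\mu, 1) gradient distribution and \beta_2 = 0.999 follow the assumption above): sample many gradient sequences, track Adam's bias-corrected second moment, and measure the variance of the resulting adaptive learning rate at each step.

    import numpy as np

    def adaptive_lr_variance(mu=0.0, beta2=0.999, eps=1e-8, steps=50, runs=5000, seed=0):
        """Estimate Var[1 / (sqrt(v_hat_t) + eps)] across random gradient sequences."""
        rng = np.random.default_rng(seed)
        grads = rng.normal(mu, 1.0, size=(runs, steps))  # gradients ~ N(mu, 1)
        v = np.zeros(runs)
        variances = []
        for t in range(1, steps + 1):
            v = beta2 * v + (1 - beta2) * grads[:, t - 1] ** 2
            v_hat = v / (1 - beta2 ** t)                 # bias-corrected second moment
            adaptive_lr = 1.0 / (np.sqrt(v_hat) + eps)
            variances.append(adaptive_lr.var())
        return variances

    curve = adaptive_lr_variance()
    print(curve[0], curve[-1])  # the variance at t=1 is far larger than at t=50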

When using a Transformer for NMT, a warmup stage is usually required to avoid convergence problems (e.g., Adam-vanilla converges around 500 PPL in Figure 2, while Adam-warmup successfully converges below 10 PPL). In further explorations, we notice that, if we use an additional 2000 samples to estimate the adaptive learning rate, the convergence problems are avoided (Adam-2k); alternatively, if we increase the value of eps, the convergence problems are also relieved (Adam-eps).
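
For concreteness, the warmup baseline here linearly ramps the learning rate over the first few thousand updates. A minimal sketch with torch.optim.lr_scheduler.LambdaLR (the model, the 4000-step warmup length, and the peak learning rate are illustrative placeholders, not the exact settings from these experiments); Adam-eps simply replaces this schedule with a larger eps value:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(512, 512)        # stand-in for the NMT model
    optimizer = Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999), eps=1e-8)

    warmup_steps = 4000                      # illustrative warmup length
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

    for step in range(100):
        loss = model(torch.randn(8, 512)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()                     # learning rate grows linearly until warmup_steps
        optimizer.zero_grad()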

Therefore, we conjecture that the large variance in the early stage causes the convergence problem, and further propose Rectified Adam by analytically reducing the large variance. More details can be found in our paper.
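
A condensed sketch of the rectification (variable names here are illustrative; radam.py in this repository is the reference implementation): RAdam tracks \rho_t, the approximated length of the simple moving average underlying the second-moment estimate, falls back to an un-adapted (SGD-with-momentum style) step while \rho_t <= 4, and otherwise scales the adaptive step by a rectification term r_t that shrinks the update while the variance estimate is still unreliable.

    import math

    def rectification(step, beta2=0.999):
        """Return (use_adaptive_lr, r_t) for the variance rectification at a given step."""
        rho_inf = 2.0 / (1.0 - beta2) - 1.0
        rho_t = rho_inf - 2.0 * step * beta2 ** step / (1.0 - beta2 ** step)
        if rho_t <= 4.0:
            # variance of the adaptive learning rate is intractable: skip the adaptive term
            return False, 1.0
        r_t = math.sqrt(((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
                        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
        return True, r_t

    for step in (1, 4, 5, 100, 10000):
        print(step, rectification(step))  # r_t approaches 1 as the estimate becomes reliable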

Questions and Discussions

Do I need to tune learning rate?

Yes; the robustness of RAdam is not unlimited. In our experiments, it works for a broader range of learning rates, but not for all learning rates.

Notes on the Transformer (more discussion can be found in our Transformer Clinic project)

Choice of the Original Transformer. We choose the original Transformer as our main study object because, without warmup, it suffers from the most serious convergence problems in our experiments. With such serious problems, our controlled experiments can better verify our hypothesis (i.e., we demonstrate that Adam-2k / Adam-eps can avoid spurious local optima with minimal changes).

Sensitivity. We observe that the Transformer is sensitive to its architecture configuration, despite its efficiency and effectiveness. For example, by changing the position of the layer norm, the model may or may not require warmup to reach good performance. Intuitively, since the gradients of the attention layers can be sparser, and the adaptive learning rates for smaller gradients have a larger variance, such configurations are more sensitive. Nevertheless, we believe this problem deserves a more in-depth analysis and is beyond the scope of our study.

Why does warmup have a bigger impact on some models than others?

Although the adaptive learning rate has a larger variance in the early stage, the exact magnitude depends on the model design. Thus, the convergence problem can be more serious for some models/tasks than for others. In our experiments, we observe that RAdam achieves consistent improvements over vanilla Adam, which verifies that the variance issue exists widely (since we can get better performance by fixing it).

What if the gradient is not zero-meaned?

As shown in Figure 1 (above), even if the gradient is not zero-meaned, the original adaptive learning rate still has a large variance at the beginning of training, so applying the rectification helps to stabilize training.

Another related concern is that, when the mean of the gradient is significantly larger than its variance, the magnitude of the "problematic" variance may not be very large (i.e., in Figure 1, when \mu equals 10, the adaptive learning rate variance is relatively small and may not cause problems). We think this provides a possible explanation of why warmup has a bigger impact on some models than others. Still, we expect that, in real-world applications, neural networks usually have some parameters that meet our assumption well (i.e., their gradient variance is larger than their gradient mean) and therefore need the rectification to stabilize training.

Why does SGD need warmup?

To the best of our knowledge, the warmup heuristic was originally designed for large-minibatch SGD [0], based on the intuition that the network changes rapidly in the early stage. However, we find that this does not explain why Adam requires warmup: notice that Adam-2k uses the same large learning rate but, with a better estimate of the adaptive learning rate, also avoids the convergence problems.

Why warmup sometimes also helps SGD still lacks theoretical support. FYI, when optimizing a simple 2-layer CNN with gradient descent, the theory of [1] can be used to show the benefit of warmup. Specifically, the learning rate must be $O(\cos \phi)$, where $\phi$ is the angle between the current weight and the ground-truth weight, and $\cos \phi$ can be very small due to the high-dimensional space and random initialization. Thus the learning rate must be very small at the beginning to guarantee convergence. $\cos \phi$ improves in the later stages, however, so the learning rate is then allowed to be larger. Their theory can, to some extent, justify why warmup is needed by gradient descent on neural networks, but it is still far-fetched for real scenarios.
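
Restated in symbols (an illustration of the statement above, not the full theorem of [1]; $w_t$ denotes the current weight and $w^{*}$ the ground-truth weight):

$$ \eta_t = O(\cos \phi_t), \qquad \cos \phi_t = \frac{\langle w_t, w^{*} \rangle}{\lVert w_t \rVert \, \lVert w^{*} \rVert}. $$

At random initialization in a $d$-dimensional space, $\cos \phi_0$ concentrates around $O(1/\sqrt{d})$, forcing a small learning rate at the start; as $\cos \phi_t$ grows during training, a larger learning rate becomes admissible, which matches the shape of a warmup schedule.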

[0] Goyal et al., Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, 2017.

[1] Du et al., Gradient Descent Learns One-hidden-layer CNN: Don't be Afraid of Spurious Local Minima, 2017.

Quick Start Guide

  1. Directly replace the vanilla Adam with RAdam without changing any settings.
  2. Further tune hyper-parameters (including the learning rate) for better performance.

Note that the major contribution of our paper is to identify why warmup is needed for Adam. Although some researchers have successfully improved their model performance (see the user comments below), considering the difficulty of training NNs, directly plugging in RAdam may not result in an immediate performance boost. In our experience, replacing vanilla Adam with RAdam usually yields better performance; however, if warmup has already been employed and tuned in the baseline method, it is necessary to also tune hyper-parameters for RAdam.
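
A minimal sketch of step 1 (assuming radam.py from this repository is on the Python path; the model and hyper-parameter values are placeholders):

    import torch
    from radam import RAdam  # radam.py from this repository

    model = torch.nn.Linear(10, 2)

    # step 1: drop-in replacement, same constructor arguments as torch.optim.Adam
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4)

    x, y = torch.randn(32, 10), torch.randn(32, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()       # step 2: if needed, further tune lr and other hyper-parameters
    optimizer.zero_grad()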

Related Posts and Repos

Unofficial Re-Implementations

RAdam is very easy to implement. We provide PyTorch implementations here, while third-party implementations can be found at:

Keras Implementation

Keras Implementation

Julia implementation in Flux.jl

Unofficial Introduction & Mentions

We provide a simple introduction in the Motivation section above, and more details can be found in our paper. Some unofficial introductions are also available (with better writing); they are listed here for reference only (the contents/claims in our paper are more accurate):

Medium Post

related Twitter Post

CSDN Post (in Chinese)

User Comments

We are happy to see that our algorithm has been found useful by some users :-)

"...I tested it on ImageNette and quickly got new high accuracy scores for the 5 and 20 epoch 128px leaderboard scores, so I know it works... https://forums.fast.ai/t/meet-radam-imo-the-new-state-of-the-art-ai-optimizer/52656

— Less Wright August 15, 2019

Thought "sounds interesting, I'll give it a try" - top 5 are vanilla Adam, bottom 4 (I only have access to 4 GPUs) are RAdam... so far looking pretty promising! pic.twitter.com/irvJSeoVfx

— Hamish Dickson (@_mishy) August 16, 2019

RAdam works great for me! It’s good to several % accuracy for free, but the biggest thing I like is the training stability. RAdam is way more stable! https://medium.com/@mgrankin/radam-works-great-for-me-344d37183943

— Grankin Mikhail August 17, 2019

"... Also, I achieved higher accuracy results using the newly proposed RAdam optimization function.... https://towardsdatascience.com/optimism-is-on-the-menu-a-recession-is-not-d87cce265b10

— Sameer Ahuja August 24, 2019

"... Out-of-box RAdam implementation performs better than Adam and finetuned SGD... https://twitter.com/ukrdailo/status/1166265186920980480

— Alex Dailo August 27, 2019

Citation

Please cite the following paper if you find our model useful. Thanks!

Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). On the Variance of the Adaptive Learning Rate and Beyond. In Proceedings of the Eighth International Conference on Learning Representations (ICLR 2020).

@inproceedings{liu2019radam,
 author = {Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},
 booktitle = {Proceedings of the Eighth International Conference on Learning Representations (ICLR 2020)},
 month = {April},
 title = {On the Variance of the Adaptive Learning Rate and Beyond},
 year = {2020}
}

Comments
  • RAdam Instability vs AdamW / Adam

    Late to the party, but once again good work to you all @LiyuanLucasLiu !

    So I was testing RAdam vs AdamW on simple linear models [i.e., Logistic Regression / Linear Regression]. Obviously, for these small problems, using new methods is a bit of overkill, but trying them on small problems [sklearn datasets like Boston, MNIST, Wine] is also important :)

    After finding the best LR using the Learning Rate Range Finder (which turns out to be the same LR for both [0.046]) + using gradient centralization + batch size = 16, with careful bias initialization (mean(y)), RAdam does seem more "stable" than AdamW.

    However, I noticed that if you do NOT standardize your data, RAdam's gradient diverges dramatically. The LR Range Test on the non-standardized data gave LR = 6.51e-05, which is super small. But RAdam still diverges.

    AdamW [lr = 1e-3] also has higher error when the data is not standardized; however, the loss doesn't diverge as much.

    I also tried, while (p < 5), manually clipping the gradients by dividing by their norm. The result is then much closer to AdamW.

    So my question is: is it expected for RAdam to diverge if the dataset is not standardized? Should AdamW be used instead? Is it because of the SGD + momentum phase when (p < 5) that this divergence is seen?

    opened by danielhanchen 8
  • Does RAdam have a Keras version?

    Hi, good job! You implemented RAdam in PyTorch; is it possible to offer a Keras version later? Much appreciated :)

    opened by xingyi-li 8
  • How to choose decay rate? (No success with RAdam - does one need a decay scheduler or gradient clipping)

    Hi Liyuan,

    I've had difficulties with RAdam on my sequence-learning problems. I am using a standard PyTorch transformer with your library. I was wondering whether one usually needs a scheduler (e.g., to reduce the learning rate at the end) or gradient clipping in addition to RAdam, since my models seem to diverge often:

    opened by brando90 5
  • math.sqrt gets a negative argument

    Hi! I have been trying to train the TransformerXL language model (https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh) with RAdam and I get *** ValueError: math domain error

    Traceback (most recent call last):
      File "train.py", line 543, in <module>
        train()
      File "train.py", line 463, in train
        optimizer.step()
      File "/transformer-xl/pytorch/radam.py", line 69, in step
        N_sma * N_sma_max / (N_sma_max - 2))) / beta1_t
    ValueError: math domain error
    

    This is because the argument to math.sqrt is negative here: https://github.com/LiyuanLucasLiu/RAdam/blob/master/language-model/model_word_ada/radam.py#L67

    What would be the right fix for this? I tried math.sqrt(abs()) but that performs worse than Adam.

    opened by akhileshgotmare 5
  • Speed performance

    Good day! Thanks for your work.

    Is RAdam more computationally efficient than Adam? In my task setting, RAdam takes much faster steps on the same batches and I'm trying to figure out why...

    opened by ivanvovk 5
  • Theory question on warmup

    Due to the lack of samples in the early stage, the adaptive learning rate has an undesirably large variance, which leads to suspicious/bad local optima -- pg. 3

    Does this apply when feeding the same dataset in a different configuration? Namely, I'm training a timeseries (16-channel EEG) CNN-LSTM classifier, and vary the input timesteps across epochs for the same model. While the information source probability distribution remains identical, what the neural net effectively "sees" differs substantially between, say, 13500 and 216000 timesteps.

    This considered, is warmup for the first epoch of every new timesteps setting advisable? Thanks

    question 
    opened by OverLordGoldDragon 4
  • ResNet56

    I am sorry to bother you again. Can you tell me the hyper-parameter settings for ResNet56? I got a very poor test accuracy of 91.1, which is worse than that of ResNet20. I set lr=0.01 and weight-decay=1e-4. Is there something wrong?

    opened by Slawlight 4
  • Does RAdam break training with different learning rates for different param_groups?

    If I understand the source code for RAdam in radam.py correctly, the global buffer caches step_size parameters depending only on state['step']. This would fail in training regimes where each param_group has its own learning rate, as the buffer would contain a step_size based on the learning rate of the first processed group.

    opened by sholderbach 3
  • Deprecated Warning in `RAdam` with torch==1.7.1

    Hi @LiyuanLucasLiu, thanks for your incredible lib. With RAdam I got better performance without changing any hyper-parameter. However, there is a deprecation warning in RAdam with torch==1.7.1:

    UserWarning: This overload of addcmul_ is deprecated:
    	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
    Consider using one of the following signatures instead:
    	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
      exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
    

    according to the doc of addcmul_ in torch==1.7.1

    Docstring:
    addcmul_(tensor1, tensor2, *, value=1) -> Tensor
    
    In-place version of :meth:`~Tensor.addcmul`
    Type:      builtin_function_or_method
    

    so to adapt to 1.7.1 and disable this warning, I only need to change exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) to exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2), right?

    opened by wenmin-wu 2
  • Any concern for using `math.sqrt` instead of `torch.sqrt`

    I find you use a lot of math.sqrt in your implementation. Any concern for not using torch.sqrt instead? I think math.sqrt is slower than torch.sqrt because it's on CPU.

    opened by wenmin-wu 2
  • distributed training generating "exp_avg error"

    Hi, thanks for sharing the code. I have tested it with single-node training and it works.

    However, when I use distributed training in PyTorch, it says:

    line 39, in step state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) KeyError: 'exp_avg'

    Any suggestions on this?

    Any comments would be much appreciated!

    bug 
    opened by h-jia 2
  • Are the plots you have wrt epochs or iterations?

    For example figure 1:

    [Figure 1 from the paper]

    In general, I am trying to figure out whether people train transformers with respect to epochs or iterations (1 iteration is one batch).

    opened by brando90 1
  • NaNs

    I have observed that RAdam can produce NaN loss in the first epochs while Adam does not. This is not just one or two experiments but a general observation. I wonder if we can merge the AdaBound clamp into RAdam to avoid this kind of issue at the very beginning of training?

    opened by thegodone 1
  • simplify add_

    Hi,

    I have a small optimization to suggest:

    Is there any particular reason to not simplify

    [line 84] p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

    into

    p_data_fp32.mul_(-group['weight_decay'] * group['lr'])

    ? Other lines could be simplified in the same manner.

    opened by LucasMourot 0
  • Will RAdam be affected by weight decay?

    Hi,

    It is said that naive Adam can degrade performance when weight decay is added, which is why AdamW was invented to make Adam compatible with weight decay. Now I have a question: does RAdam work well if I use it together with weight decay?

    opened by CoinCheung 0