Adabelief-Optimizer - Repository for NeurIPS 2020 Spotlight "AdaBelief Optimizer: Adapting stepsizes by the belief in observed gradients"

Overview

AdaBelief Optimizer

NeurIPS 2020 Spotlight. Trains as fast as Adam, generalizes as well as SGD, and is stable for training GANs.

Release of package

We have released adabelief-pytorch==0.2.0 and adabelief-tf==0.2.0. Please use the latest version from pip. Source code is available in the folders pypi_packages/adabelief_pytorch0.2.0 and pypi_packages/adabelief_tf0.2.0.

Table of Contents

External Links

Project Page, arXiv, Reddit, Twitter, BiliBili (Chinese), BiliBili (English), Youtube

Link to code for extra experiments with AdaBelief

Update for adabelief-pytorch==0.2.0 (Crucial)

The defaults of several arguments have been modified in order to fit the needs of general tasks such as GAN and Transformer training. Please check whether you specify these arguments or rely on the defaults when upgrading from version 0.0.5 to a higher version.

Version                  | epsilon | weight_decouple | rectify
adabelief-pytorch==0.0.5 | 1e-8    | False           | False
latest version (0.2.0)   | 1e-16   | True            | True

Update for adabelief-tf==0.2.0 (Crucial)

In adabelief-tf==0.1.0, we modified adabelief-tf to have the same features as adabelief-pytorch, including decoupled weight decay and learning rate rectification, and added support for TensorFlow>=2.0 and Keras. The source code is in pypi_packages/adabelief_tf0.1.0. We tested it on a text classification task and a word embedding task. The default values have been updated; please check whether you specify these arguments or use the defaults when upgrading from version 0.0.1 to a higher version:

Version                | epsilon | weight_decouple                      | rectify
adabelief-tf==0.0.1    | 1e-8    | Not supported                        | Not supported
latest version (0.2.0) | 1e-14   | Supported (not an option in arguments) | default: True

Quick Guide

  • Check that the code uses the latest official implementation (adabelief-pytorch==0.2.0, adabelief-tf==0.2.0); default hyper-parameters differ from the old versions.

  • Check all hyper-parameters; DO NOT simply use the defaults.

    Epsilon in AdaBelief plays a different role from Adam (typically eps_adabelief = eps_adam * eps_adam).
    (The eps of Adam in TensorFlow is 1e-7 and in PyTorch is 1e-8; take this into account when using AdaBelief in TensorFlow.)

    If SGD is better than Adam -> set a large eps (1e-8) in adabelief-pytorch (1e-7 in TensorFlow).
    If SGD is worse than Adam -> set a small eps (1e-16) in adabelief-pytorch (1e-14 in TensorFlow; rectify=True often helps).

    If AdamW is better than Adam -> turn on weight_decouple in adabelief-pytorch (this is always on in adabelief-tf==0.1.0 and cannot be turned off).
    Note that the default weight decay is very different for Adam and AdamW; you might need to account for this when using AdaBelief with or without decoupled weight decay.

  • Check ALL hyper-parameters. Refer to our github page for a list of recommended hyper-parameters
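As a rough numeric illustration of the eps rule above (a sketch only; the helper below is hypothetical and not part of any package):

def adam_to_adabelief_eps(eps_adam):
    # Heuristic starting point from the guide above: eps_adabelief ~= eps_adam ** 2.
    return eps_adam ** 2

print(adam_to_adabelief_eps(1e-8))  # PyTorch Adam default     -> 1e-16
print(adam_to_adabelief_eps(1e-7))  # TensorFlow Adam default  -> 1e-14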

Table of Hyper-parameters

Please check that you have specified all arguments and that your version is the latest; the defaults might not be suitable for different tasks. See the tables below.

Hyper-parameters in PyTorch

  • Note that weight decay varies with the task; for each task the weight decay is kept untuned from the original repository (only the optimizer and its hyper-parameters were changed).
Task                           | lr   | beta1 | beta2 | epsilon | weight_decay | weight_decouple      | rectify | fixed_decay | amsgrad
Cifar                          | 1e-3 | 0.9   | 0.999 | 1e-8    | 5e-4         | False                | False   | False       | False
ImageNet                       | 1e-3 | 0.9   | 0.999 | 1e-8    | 1e-2         | True                 | False   | False       | False
Object detection (PASCAL)      | 1e-4 | 0.9   | 0.999 | 1e-8    | 1e-4         | False                | False   | False       | False
LSTM-1layer                    | 1e-3 | 0.9   | 0.999 | 1e-16   | 1.2e-6       | False                | False   | False       | False
LSTM 2,3 layer                 | 1e-2 | 0.9   | 0.999 | 1e-12   | 1.2e-6       | False                | False   | False       | False
GAN (small)                    | 2e-4 | 0.5   | 0.999 | 1e-12   | 0            | True=False (decay=0) | False   | False       | False
SN-GAN (large)                 | 2e-4 | 0.5   | 0.999 | 1e-16   | 0            | True=False (decay=0) | True    | False       | False
Transformer                    | 5e-4 | 0.9   | 0.999 | 1e-16   | 1e-4         | True                 | True    | False       | False
Reinforcement (Rainbow)        | 1e-4 | 0.9   | 0.999 | 1e-10   | 0.0          | True=False (decay=0) | True    | False       | False
Reinforcement (HalfCheetah-v2) | 1e-3 | 0.9   | 0.999 | 1e-12   | 0.0          | True=False (decay=0) | True    | False       | False
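For instance, the Transformer row of the table above translates into a constructor call like the sketch below (the toy model is a placeholder; tune the values for your own task):

import torch
from adabelief_pytorch import AdaBelief

model = torch.nn.Linear(512, 512)  # placeholder model
optimizer = AdaBelief(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-16,
    weight_decay=1e-4,
    weight_decouple=True,
    rectify=True,
    fixed_decay=False,
    amsgrad=False,
)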

Hyper-parameters in Tensorflow (eps in Tensorflow might need to be larger than in PyTorch)

epsilon is used in a different way in TensorFlow (default 1e-7) compared to PyTorch (default 1e-8), so eps in TensorFlow might need to be larger than in PyTorch (perhaps 100 times larger, e.g. eps=1e-16 in PyTorch vs. eps=1e-14 in TensorFlow). I personally don't have much experience with TensorFlow, so you will likely need to tune eps slightly.

Installation and usage

1. PyTorch implementations

(Results in the paper are all generated using the PyTorch implementation in the adabelief-pytorch package, which is the ONLY package that I have extensively tested so far.)

AdaBelief

Please install the latest version (0.2.0); previous versions (e.g. 0.0.5) use different default arguments.

pip install adabelief-pytorch==0.2.0
from adabelief_pytorch import AdaBelief
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16, betas=(0.9, 0.999), weight_decouple=True, rectify=False)
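A minimal training-loop sketch showing that AdaBelief is used like any other PyTorch optimizer (the toy model and random data are placeholders):

import torch
from adabelief_pytorch import AdaBelief

model = torch.nn.Linear(16, 1)                  # placeholder model
criterion = torch.nn.MSELoss()
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16, betas=(0.9, 0.999),
                      weight_decouple=True, rectify=False)

x, y = torch.randn(32, 16), torch.randn(32, 1)  # dummy batch
for step in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()                            # drop-in replacement for Adam's step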

Adabelief with Ranger optimizer

pip install ranger-adabelief==0.1.0
from ranger_adabelief import RangerAdaBelief
optimizer = RangerAdaBelief(model.parameters(), lr=1e-3, eps=1e-12, betas=(0.9,0.999))

2. Tensorflow implementation (eps of AdaBelief in Tensorflow is larger than in PyTorch, same for Adam)

pip install adabelief-tf==0.2.0
from adabelief_tf import AdaBeliefOptimizer
optimizer = AdaBeliefOptimizer(learning_rate=1e-3, epsilon=1e-14, rectify=False)
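For a Keras model, the optimizer is passed to compile() as usual; a minimal sketch with a toy model (layer sizes are placeholders):

import tensorflow as tf
from adabelief_tf import AdaBeliefOptimizer

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(1),
])
optimizer = AdaBeliefOptimizer(learning_rate=1e-3, epsilon=1e-14, rectify=False)
model.compile(optimizer=optimizer, loss="mse")  # used like any other Keras optimizer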

A quick look at the algorithm

Adam and AdaBelief are summarized in Algo.1 and Algo.2, where all operations are element-wise and the differences are marked in blue. Note that no extra parameters are introduced in AdaBelief. For simplicity, we omit the bias correction step. Specifically, in Adam the update direction is m_t / sqrt(v_t), where v_t is the EMA (Exponential Moving Average) of g_t^2; in AdaBelief the update direction is m_t / sqrt(s_t), where s_t is the EMA of (g_t - m_t)^2. Intuitively, viewing m_t as the prediction of g_t, AdaBelief takes a large step when the observation g_t is close to the prediction m_t, and a small step when the observation greatly deviates from the prediction.
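A minimal NumPy sketch of the two update rules described above (bias correction, weight decay, and the learning rate are omitted; this is an illustration, not the packaged implementation):

import numpy as np

def adam_direction(m, v, g, beta1=0.9, beta2=0.999, eps=1e-8):
    # v_t is the EMA of g_t**2 (how large the gradient is).
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    return m, v, m / (np.sqrt(v) + eps)

def adabelief_direction(m, s, g, beta1=0.9, beta2=0.999, eps=1e-16):
    # s_t is the EMA of (g_t - m_t)**2: how far the gradient deviates from its prediction m_t.
    m = beta1 * m + (1 - beta1) * g
    s = beta2 * s + (1 - beta2) * (g - m)**2
    return m, s, m / (np.sqrt(s) + eps)

# When observed gradients stay close to their EMA, s_t is small and AdaBelief takes a
# larger step than Adam; when they deviate wildly, s_t is large and the step shrinks.
g = np.array([1.0, -0.5])
m, v, step_adam = adam_direction(np.zeros_like(g), np.zeros_like(g), g)
m, s, step_adabelief = adabelief_direction(np.zeros_like(g), np.zeros_like(g), g)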

Reproduce results in the paper

(Comparison with 8 other optimizers: SGD, Adam, AdaBound, RAdam, AdamW, Yogi, MSVAG, Fromage)

See the folder PyTorch_Experiments; in each subfolder, execute sh run.sh. See readme.txt in each subfolder for visualization, or refer to the jupyter notebooks.

Results on Image Recognition

Results on GAN training

Results on a small GAN with vanilla CNN generator

Results on Spectral Normalization GAN with a ResNet generator

Results on LSTM

Results on Transformer

Results on Toy Example

Discussions

Installation

Please install the latest version from pip; old versions might suffer from bugs. Source code for the up-to-date packages is available in the folder pypi_packages.

Discussion on hyper-parameters

AdaBelief uses a different denominator from Adam, and is orthogonal to other techniques such as rectification, decoupled weight decay, and weight averaging. This implies that if you use some of these techniques with Adam, you might still need them to get a good result with AdaBelief.

  • epsilon in AdaBelief plays a different role from Adam; typically, when you use epsilon=x in Adam, epsilon=x*x will give similar results in AdaBelief. The default value epsilon=1e-8 is not a good option in many cases; in versions >0.1.0 the default eps is set to 1e-16.

  • If your task needs a "non-adaptive" optimizer, i.e. SGD performs much better than Adam(W), such as image recognition, set a large epsilon (e.g. 1e-8) for AdaBelief to make it more non-adaptive; if your task needs a really adaptive optimizer, i.e. Adam is much better than SGD, such as GANs and Transformers, then the recommended epsilon for AdaBelief is small (1e-12, 1e-16, ...).

  • If decoupled weight decay is very important for your task, i.e. AdamW is much better than Adam, then set weight_decouple to True to turn on decoupled decay in AdaBelief. Note that many optimizers use decoupled weight decay without offering it as an option (e.g. RAdam), but we provide it as an option so users are aware of which technique is actually used.

  • Don't use "gradient threshold" (clamping each element independently) with AdaBelief; it can result in division by 0 and an exploding update. "Gradient clip" (shrinking the amplitude of the gradient vector while keeping its direction) is fine, though from my limited experience the clip range sometimes needs to be the same as or larger than with Adam (see the sketch below).
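A hedged illustration of the two operations using standard PyTorch utilities (nothing here is specific to the AdaBelief package):

import torch

model = torch.nn.Linear(8, 1)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# "Gradient threshold": element-wise clamp. Avoid with AdaBelief: if the gradient
# is saturated for many steps, (g_t - m_t) stays near 0 and the update can explode.
# for p in model.parameters():
#     p.grad.clamp_(-0.1, 0.1)

# "Gradient clip": rescale the whole gradient vector, keeping its direction.
# This is fine with AdaBelief, though the clip range may need to be the same as
# or larger than what you used with Adam.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)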

Discussion on algorithms

1. Weight Decay:
  • Decoupling (argument weight_decouple appears in AdaBelief and RangerAdaBelief):
    Currently there are two ways to perform weight decay for adaptive optimizers: apply it directly to the gradient (Adam), or decouple weight decay from gradient descent (AdamW). This is passed to the optimizer via the argument weight_decouple (default: False in versions up to 0.0.5, True in the latest version).

  • Fixed ratio (argument fixed_decay (default: False) appears in AdaBelief):
    (1) If weight_decouple == False, then this argument does not affect optimization.
    (2) If weight_decouple == True:

      If fixed_decay == False, the weight is multiplied by 1 - lr * weight_decay.
      If fixed_decay == True, the weight is multiplied by 1 - weight_decay. This is implemented as an option but was not used to produce the results in the paper.

  • What is the actual weight decay we are using?
    This is seldom discussed in the literature, but personally I think it's very important. When we set weight_decay=1e-4 for SGD, the weight is scaled by 1 - lr * weight_decay every step. Two points need to be emphasized: (1) lr in SGD is typically larger than in Adam (0.1 vs 0.001), so the weight decay in Adam needs to be set to a larger number to compensate; (2) lr decays during training, which means we typically apply a larger effective weight decay in early phases and a smaller one in late phases.
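A schematic of the three decay variants above for a single parameter (illustrative only; the actual logic lives inside the packaged step()):

def apply_weight_decay(w, grad, lr, weight_decay, weight_decouple, fixed_decay):
    # Illustrative sketch, not the packaged implementation.
    if not weight_decouple:
        # Adam-style: decay is folded into the gradient before the adaptive update.
        grad = grad + weight_decay * w
    elif fixed_decay:
        # Decoupled, fixed ratio: the weight shrinks by (1 - weight_decay) each step.
        w = w * (1 - weight_decay)
    else:
        # Decoupled, lr-scaled (AdamW-style): the weight shrinks by (1 - lr * weight_decay).
        w = w * (1 - lr * weight_decay)
    return w, grad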

2. Epsilon:

AdaBelief seems to require a different epsilon from Adam. In the CV tasks in this paper, epsilon is set to 1e-8; for GAN training it is set to 1e-16. We recommend trying different epsilon values in practice and sweeping through a large region. We recommend eps=1e-8 when SGD outperforms Adam, as in many CV tasks, and eps=1e-16 when Adam outperforms SGD, as with GANs and Transformers. Sometimes you might need to try eps=1e-12, for example in some reinforcement learning tasks.

3. Rectify (argument rectify in AdaBelief):

Whether to turn on the rectification as in RAdam. The rectification basically uses SGD in early phases for warmup, then switches to Adam. Rectification is implemented as an option, but was not used to produce the results in the paper.

4. AMSgrad (argument amsgrad (default: False) in AdaBelief):

Whether to take the max (over history) of the denominator, as in AMSGrad. It is set to False for all experiments.

5. Details to reproduce results
  • Results in the paper are generated using the PyTorch implementation in adabelief-pytorch package. This is the ONLY package that I have extensively tested for now.
  • We also provide a modification of ranger optimizer in ranger-adabelief which combines RAdam + LookAhead + Gradient Centralization + AdaBelief, but this is not used in the paper and is not extensively tested.
  • The original adabelief-tf (0.0.1) was a naive implementation in Tensorflow. It lacked many features such as decoupled weight decay, and was not extensively tested. Currently I don't have plans to improve it since I seldom use Tensorflow; please contact me if you want to collaborate and improve it.
  • adabelief-tf==0.1.0 supports the same features as adabelief-pytorch==0.1.0, including decoupled weight decay and rectification, but I have not had the chance to test it as extensively as the PyTorch version.
6. Learning rate schedule

The experiments on Cifar are the same as the demo in AdaBound, with the only difference being the optimizer. The ImageNet experiment uses a different learning rate schedule: typically the lr is decayed by 1/10 at epochs 30 and 60, and training ends at epoch 90. For reasons I have not extensively investigated, AdaBelief performs well when the lr is decayed at epochs 70 and 80 with training ending at epoch 90, while using the default lr schedule produces a slightly worse result. If you have any ideas on this, please open an issue here or email me.

7. Some experience with RNN

I got some feedback on RNNs from the reddit discussion; here are a few tips:

  • Epsilon is suggested to be set to a smaller value for RNNs (e.g. 1e-12, 1e-16). Please try different epsilon values; it varies from task to task.
  • I might have confused "gradient threshold" with "gradient clip" in a previous readme; to clarify:
    (1) By "gradient threshold" I refer to an element-wise operation, which only keeps values within a certain region [a, b]; values outside this region are set to a or b respectively.
    (2) By "gradient clip" I refer to an operation on a vector or tensor. Suppose X is a tensor: if ||X|| > thres, then X <- X/||X|| * thres. Taking X as a vector, "gradient clip" shrinks the amplitude but keeps the direction.
    (3) "Gradient threshold" is incompatible with AdaBelief, because if g_t is thresholded for a long time, then |g_t - m_t| ~ 0 and the division will explode; "gradient clip" is fine for AdaBelief, though the clip range still needs tuning (perhaps AdaBelief needs a larger range than Adam).
8. Contact

Please contact me at [email protected] or open an issue here if you would like to help improve the project, especially the Tensorflow version, explore combinations with other methods to create a better optimizer, or discuss the theory. Any thoughts are welcome!

Update Plan

To do

Done

  • Updated results on an SN-GAN are at https://github.com/juntang-zhuang/SNGAN-AdaBelief; AdaBelief achieves 12.36 FID (lower is better) on Cifar10, while Adam achieves 13.25 (number taken from the log of the official PyTorch-studioGAN repository).
  • LSTM experiments uploaded to PyTorch_Experiments/LSTM
  • Identified the problem with Transformer training under PyTorch 1.4: an old version of fairseq is incompatible with newer versions of PyTorch; it works fine with the latest fairseq.
    Code on Transformer to work with PyTorch 1.6 is at: https://github.com/juntang-zhuang/fairseq-adabelief
    Code for transformer to work with PyTorch 1.1 and CUDA9.0 is at: https://github.com/juntang-zhuang/transformer-adabelief
  • Tested on a toy example of reinforcement learning.
  • Released adabelief-pytorch==0.1.0 and adabelief-tf==0.1.0. The Tensorflow version now supports TF>=2.0 and Keras, with the same features as the PyTorch version, including decoupled weight decay and rectification.
  • Released adabelief-pytorch==0.2.0. Fixed the error with coupled weight decay and the amsgrad update in adabelief-pytorch==0.1.0. Added an option to disable the message printing by specifying print_change_log=False when initializing the optimizer.
  • Released adabelief-tf==0.2.0. Added an option to disable the message printing by specifying print_change_log=False when initializing the optimizer. Deleted redundant computations, so 0.2.0 should be faster than 0.1.0. Removed the dependency on tensorflow-addons.
  • adabelief-pytorch==0.2.1 is compatible with mixed-precision training.
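A hedged sketch of pairing AdaBelief with PyTorch native AMP, as mentioned in the last item above (the standard torch.cuda.amp pattern; requires a CUDA device, and the toy model/data are placeholders):

import torch
from adabelief_pytorch import AdaBelief

model = torch.nn.Linear(16, 1).cuda()
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16, betas=(0.9, 0.999),
                      weight_decouple=True, rectify=False)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(32, 16, device="cuda")
y = torch.randn(32, 1, device="cuda")
for step in range(10):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)          # unscale gradients, then call optimizer.step()
    scaler.update()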

Citation

@article{zhuang2020adabelief,
  title={AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients},
  author={Zhuang, Juntang and Tang, Tommy and Ding, Yifan and Tatikonda, Sekhar and Dvornek, Nicha and Papademetris, Xenophon and Duncan, James},
  journal={Conference on Neural Information Processing Systems},
  year={2020}
}
Comments
  • Tensorflow Implementation

    When I tried the optimizer with the TensorFlow CycleGAN, it takes a lot of time to complete one step. Is it a problem with the use of the GPU or the framework, or with the optimizer itself?

    Thanks in Advance

    opened by ManoharSai2000 14
  • Do the weight decay before using grad

    Uncoupled weight decay was done after the gradient was used to calculate momentum and variance. Fixed it. Found this while writing a tutorial implementation.

    opened by vpj 13
  • Results on ImageNet with tuning weight decay

    I quickly ran some experiments on ImageNet with different weight decay rates.

    Using AdamW with wd=1e-2 and setting the other hyper-parameters the same as reported in the AdaBelief paper, the average accuracy over 3 runs is 69.73%, much better than the number compared in the paper. I will keep updating results for other optimizers and weight decay rates.

    opened by XuezheMax 11
  • Different usage of eps between "A quick look at the algorithm" and the code

    Hi

    I have a question.

    In "A quick look at the algorithm" in README.md, eps is added to shat_t. But in the code, eps is added to s_t(exp_avg_var) instead of shat_t.

    Also, in the PyTorch code only, if amsgrad is True, eps is added to max_exp_avg_var instead of exp_avg_var (s_t) or shat_t.

    Which behavior is correct?

    pytorch code tensorflow code

    opened by tatsuhiko-inoue 10
  • Matlab implementation

    How about a Matlab test case?

    I tried to implement a Matlab version of AdaBelief and compare it with SGD with momentum at https://github.com/pcwhy/AdaBelief-Matlab. I found that sometimes AdaBelief is not guaranteed to converge to the optimal solution that SGD with momentum can reach.

    opened by pcwhy 8
  • Fix problem with sparse layers in tf0.1.0

    The _resource_apply_sparse function applies update according to the indices. Gathering the elements before update would fix the error caused by ResourceScatterAdd. The fix has been tested on word embeddings.

    opened by cryu854 8
  • Issues on AdaBelief-tensorflow

    Hi! I had some trouble using AdaBelief in a simple LSTM training. What could be the reason for this?

    CODE:

    from adabelief_tf import AdaBeliefOptimizer
    tf.keras.backend.clear_session()
    multivariate_lstmA = tf.keras.models.Sequential([
        LSTM(100, input_shape=input_shape, return_sequences=True),
        Flatten(),
        Dense(200, activation='relu'),
        Dropout(0.1),
        Dense(1)
    ])
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'multivariate_lstmA.h5', monitor=('val_loss'), save_best_only=True)
    optimizer = AdaBeliefOptimizer(learning_rate=1e-3, epsilon=1e-14, rectify=False)
    multivariate_lstmA.compile(loss=loss, optimizer=optimizer, metrics=metric)

    RESULT: Please check your arguments if you have upgraded adabelief-tf from version 0.0.1. Modifications to default arguments:

    Version                  | eps   | weight_decouple | rectify
    adabelief-tf==0.0.1      | 1e-08 | Not supported   | Not supported
    Current version (0.1.0)  | 1e-14 | supported       | default: True

    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer

    opened by dusk666 7
  • Similarity to AdaHessian

    Hi, first of all, thank you very much for sharing the code for AdaBelief, it looks like a very promising optimizer! :) Have you considered comparing it to AdaHessian? I feel like AdaHessian is using the same trick as you (but they do it less efficiently).

    opened by davda54 7
  • Instability in training an RNN

    Hello,

    Congratulations on this awesome paper and for providing the code to test it. I'm training a small RNN network (2 layers of SRU (https://github.com/asappresearch/sru), 256 hidden size, CRF at the end) for the NER task.

    Following the Readme, I disabled gradient clipping and used an epsilon of 1e-12. This task converges great with Ranger, SGD, and Adam, but with AdaBelief I get some loss explosion randomly.

    Am I doing something wrong?

    accuracy: 0.8366, accuracy3: 0.8366, precision-overall: 0.0040, recall-overall: 0.0163, f1-measure-overall: 0.0065, batch_loss: 7236.0938, loss: 57461.7845 ||: : 30it [09:29, 18.99s/it]                        
    accuracy: 0.9254, accuracy3: 0.9255, precision-overall: 0.1612, recall-overall: 0.2104, f1-measure-overall: 0.1825, batch_loss: 51126.7266, loss: 18637.9896 ||: : 30it [08:47, 17.60s/it]                       
    accuracy: 0.9645, accuracy3: 0.9645, precision-overall: 0.3207, recall-overall: 0.4666, f1-measure-overall: 0.3801, batch_loss: 11046.6484, loss: 13583.7611 ||: : 30it [08:59, 17.99s/it]                      
    accuracy: 0.9828, accuracy3: 0.9829, precision-overall: 0.6505, recall-overall: 0.7602, f1-measure-overall: 0.7011, batch_loss: 8434.5000, loss: 3932.2246 ||: : 29it [08:37, 17.86s/it]                       
    accuracy: 0.9856, accuracy3: 0.9856, precision-overall: 0.7832, recall-overall: 0.8383, f1-measure-overall: 0.8098, batch_loss: 122.3125, loss: 3008.3288 ||: : 29it [09:13, 19.09s/it]                        
    accuracy: 0.9930, accuracy3: 0.9930, precision-overall: 0.8261, recall-overall: 0.8861, f1-measure-overall: 0.8551, batch_loss: 2115.6699, loss: 1362.0373 ||: : 30it [08:55, 17.84s/it]                       
    accuracy: 0.9948, accuracy3: 0.9948, precision-overall: 0.8893, recall-overall: 0.9243, f1-measure-overall: 0.9065, batch_loss: 1569.0469, loss: 1011.7590 ||: : 30it [08:33, 17.10s/it]                       
    accuracy: 0.9972, accuracy3: 0.9972, precision-overall: 0.9367, recall-overall: 0.9571, f1-measure-overall: 0.9468, batch_loss: 591.5840, loss: 426.5681 ||: : 29it [08:58, 18.56s/it]                       
    accuracy: 0.9977, accuracy3: 0.9977, precision-overall: 0.9514, recall-overall: 0.9660, f1-measure-overall: 0.9587, batch_loss: 23.7188, loss: 279.9471 ||: : 29it [08:32, 17.69s/it]                        
    accuracy: 0.9977, accuracy3: 0.9977, precision-overall: 0.9501, recall-overall: 0.9627, f1-measure-overall: 0.9564, batch_loss: 93.2188, loss: 243.8314 ||: : 30it [09:16, 18.54s/it]                        
    accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9641, recall-overall: 0.9732, f1-measure-overall: 0.9686, batch_loss: 53.5000, loss: 199.5779 ||: : 29it [08:44, 18.10s/it]                        
    accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9702, recall-overall: 0.9789, f1-measure-overall: 0.9745, batch_loss: 52.5781, loss: 156.1823 ||: : 30it [09:14, 18.47s/it]                       
    accuracy: 0.9994, accuracy3: 0.9994, precision-overall: 0.9816, recall-overall: 0.9871, f1-measure-overall: 0.9843, batch_loss: 61.4688, loss: 69.1954 ||: : 29it [09:01, 18.66s/it]                        
    accuracy: 0.9990, accuracy3: 0.9990, precision-overall: 0.9813, recall-overall: 0.9858, f1-measure-overall: 0.9836, batch_loss: 29.5312, loss: 90.0869 ||: : 29it [08:51, 18.33s/it]                        
    accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9846, recall-overall: 0.9896, f1-measure-overall: 0.9871, batch_loss: 74.0625, loss: 53.9213 ||: : 29it [08:40, 17.94s/it]                       
    accuracy: 0.9995, accuracy3: 0.9995, precision-overall: 0.9822, recall-overall: 0.9868, f1-measure-overall: 0.9845, batch_loss: 33.9844, loss: 49.5508 ||: : 30it [08:35, 17.19s/it]                       
    accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9854, recall-overall: 0.9869, f1-measure-overall: 0.9862, batch_loss: 19.3906, loss: 34.1199 ||: : 30it [09:03, 18.11s/it]                       
    accuracy: 0.9995, accuracy3: 0.9995, precision-overall: 0.9938, recall-overall: 0.9950, f1-measure-overall: 0.9944, batch_loss: 709.4336, loss: 48.0945 ||: : 29it [08:38, 17.88s/it]                      
    accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9914, recall-overall: 0.9937, f1-measure-overall: 0.9925, batch_loss: 14.9688, loss: 38.2326 ||: : 29it [08:36, 17.79s/it]                       
    accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9852, recall-overall: 0.9894, f1-measure-overall: 0.9873, batch_loss: 79.4688, loss: 51.3397 ||: : 29it [08:55, 18.46s/it]                       
    accuracy: 0.9998, accuracy3: 0.9998, precision-overall: 0.9926, recall-overall: 0.9936, f1-measure-overall: 0.9931, batch_loss: 39.0625, loss: 22.0619 ||: : 30it [09:00, 18.03s/it]                      
    accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9915, recall-overall: 0.9937, f1-measure-overall: 0.9926, batch_loss: 16.9062, loss: 33.6324 ||: : 30it [09:32, 19.07s/it]                       
    accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9939, recall-overall: 0.9947, f1-measure-overall: 0.9943, batch_loss: 0.7812, loss: 27.4840 ||: : 30it [09:13, 18.44s/it]                        
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9951, recall-overall: 0.9959, f1-measure-overall: 0.9955, batch_loss: 27.0786, loss: 15.0342 ||: : 29it [09:08, 18.92s/it]                      
    accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9938, recall-overall: 0.9963, f1-measure-overall: 0.9951, batch_loss: 7.7500, loss: 25.8246 ||: : 29it [09:00, 18.63s/it]                       
    accuracy: 0.9998, accuracy3: 0.9998, precision-overall: 0.9957, recall-overall: 0.9966, f1-measure-overall: 0.9961, batch_loss: 27.6875, loss: 17.3096 ||: : 30it [08:47, 17.58s/it]                      
    accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9949, recall-overall: 0.9968, f1-measure-overall: 0.9958, batch_loss: 35.4727, loss: 26.2837 ||: : 29it [08:24, 17.40s/it]                      
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9968, recall-overall: 0.9975, f1-measure-overall: 0.9972, batch_loss: 40.9062, loss: 13.3182 ||: : 30it [09:12, 18.42s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9965, recall-overall: 0.9979, f1-measure-overall: 0.9972, batch_loss: 0.5000, loss: 8.9580 ||: : 29it [08:27, 17.51s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9973, recall-overall: 0.9978, f1-measure-overall: 0.9976, batch_loss: 0.6250, loss: 10.6955 ||: : 29it [08:08, 16.84s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9983, recall-overall: 0.9990, f1-measure-overall: 0.9986, batch_loss: 5.4375, loss: 9.3031 ||: : 30it [08:18, 16.63s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9978, recall-overall: 0.9982, f1-measure-overall: 0.9980, batch_loss: 6.3047, loss: 6.1776 ||: : 29it [08:19, 17.22s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9977, recall-overall: 0.9980, f1-measure-overall: 0.9979, batch_loss: 0.8438, loss: 5.7469 ||: : 29it [08:14, 17.04s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9975, recall-overall: 0.9976, f1-measure-overall: 0.9976, batch_loss: 9.0176, loss: 7.7605 ||: : 30it [08:18, 16.60s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9964, recall-overall: 0.9966, f1-measure-overall: 0.9965, batch_loss: 1.8438, loss: 11.5324 ||: : 30it [08:11, 16.37s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9962, recall-overall: 0.9969, f1-measure-overall: 0.9966, batch_loss: 9.9844, loss: 12.8704 ||: : 29it [08:27, 17.51s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9980, recall-overall: 0.9988, f1-measure-overall: 0.9984, batch_loss: 3.5742, loss: 4.8728 ||: : 30it [08:36, 17.23s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9993, recall-overall: 0.9993, f1-measure-overall: 0.9993, batch_loss: 0.7031, loss: 2.8980 ||: : 30it [08:26, 16.88s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9986, recall-overall: 0.9987, f1-measure-overall: 0.9986, batch_loss: 7.0625, loss: 4.2808 ||: : 30it [08:50, 17.69s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9988, recall-overall: 0.9990, f1-measure-overall: 0.9989, batch_loss: 2.1562, loss: 4.5667 ||: : 30it [08:08, 16.28s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9987, recall-overall: 0.9990, f1-measure-overall: 0.9988, batch_loss: 15.0625, loss: 3.0480 ||: : 30it [08:36, 17.22s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9986, recall-overall: 0.9989, f1-measure-overall: 0.9987, batch_loss: 21.6094, loss: 2.7449 ||: : 30it [08:18, 16.60s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9995, recall-overall: 0.9997, f1-measure-overall: 0.9996, batch_loss: 0.7812, loss: 2.5399 ||: : 29it [08:06, 16.78s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9995, recall-overall: 0.9995, f1-measure-overall: 0.9995, batch_loss: -0.0625, loss: 2.2463 ||: : 29it [08:13, 17.03s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9992, recall-overall: 0.9993, f1-measure-overall: 0.9992, batch_loss: 2.7969, loss: 3.0429 ||: : 30it [08:21, 16.71s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9997, recall-overall: 0.9998, f1-measure-overall: 0.9997, batch_loss: 2.4316, loss: 2.3025 ||: : 30it [08:30, 17.02s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9996, recall-overall: 0.9998, f1-measure-overall: 0.9997, batch_loss: 1.3281, loss: 4.6582 ||: : 29it [08:09, 16.89s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9973, recall-overall: 0.9980, f1-measure-overall: 0.9977, batch_loss: -0.0000, loss: 4.8893 ||: : 30it [08:36, 17.23s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9956, recall-overall: 0.9976, f1-measure-overall: 0.9966, batch_loss: 0.6875, loss: 4.2254 ||: : 30it [08:21, 16.71s/it]
    accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9980, recall-overall: 0.9981, f1-measure-overall: 0.9981, batch_loss: 0.0312, loss: 5.8634 ||: : 30it [08:10, 16.34s/it]
    accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9787, recall-overall: 0.9515, f1-measure-overall: 0.9649, batch_loss: 22304.5000, loss: 749.8296 ||: : 30it [08:32, 17.08s/it]
    accuracy: 0.9570, accuracy3: 0.9570, precision-overall: 0.2782, recall-overall: 0.4189, f1-measure-overall: 0.3343, batch_loss: 731722.4375, loss: 65948.9812 ||: : 30it [08:25, 16.85s/it]
    accuracy: 0.9383, accuracy3: 0.9383, precision-overall: 0.1668, recall-overall: 0.2775, f1-measure-overall: 0.2083, batch_loss: 778091.5625, loss: 337316.9677 ||: : 29it [08:08, 16.83s/it]
    Epoch    53: reducing learning rate of group 0 to 3.0000e-03.
    accuracy: 0.9668, accuracy3: 0.9669, precision-overall: 0.3510, recall-overall: 0.5322, f1-measure-overall: 0.4230, batch_loss: 77123.0000, loss: 253831.3728 ||: : 29it [08:23, 17.36s/it]
    accuracy: 0.9767, accuracy3: 0.9767, precision-overall: 0.4897, recall-overall: 0.6151, f1-measure-overall: 0.5453, batch_loss: -1.0000, loss: 137048.0448 ||: : 30it [08:35, 17.19s/it]
    accuracy: 0.9839, accuracy3: 0.9839, precision-overall: 0.6340, recall-overall: 0.7326, f1-measure-overall: 0.6798, batch_loss: 43615.0000, loss: 103847.1062 ||:  19%|#8        | 5/27 [01:36<07:03, 19.27s/it]
    
    opened by bratao 7
  • Should this work with Mixed precision training (AMP)

    Hi, just a question: is this optimizer compatible with mixed precision training (AMP)? I tried to use it in combination with lucidrains' lightweight-gan implementation, which uses the PyTorch version of this optimizer, but after a few hundred iterations my losses go to NaN and eventually cause a division-by-zero error. I don't see the same problem when using the standard Adam optimizer.

    opened by Mut1nyJD 6
  • 0.1.0 changes for ranger_adabelief

    Hi @juntang-zhuang, super excited to try the new improvements. I saw that you did not update the ranger version. Do you plan to add the improvements there too?

    opened by bratao 6
  • Suppressing weight decoupling and rectification messages

    Is there a way to suppress these messages by setting some parameters explicitly when they are enabled?

    Weight decoupling enabled in AdaBelief
    Rectification enabled in AdaBelief
    

    I skimmed through the code and did not notice any parameter that lets us do so. I apologize if I have overlooked any part of the code/documentation. Thank you in advance for your reply.

    Environment

    • adabelief_pytorch 0.2.1
    • Python 3.8.10
    opened by gunsodo 1
  • Inconsistent computation of weight_decay and grad_residual among pytorch versions

    Hi, I was looking at the various versions you have in the pypi_packages folder and noticed that the order of computation of weight decay (which for some options modifies grad) and of grad_residual (which uses grad) differs between versions. In adabelief_pytorch0.0.5, adabelief_pytorch0.2.0, and adabelief_pytorch0.2.1, weight decay is done before computing grad_residual, but in adabelief_pytorch0.1.0 it is done afterwards. It seems that adabelief_pytorch0.1.0 more closely follows what your paper describes as the second-order momentum computation. Shouldn't the others be changed to align with adabelief_pytorch0.1.0?

    opened by sjscotti 5
  • Documentation (at least for TF) and weight_decouple is not an option

    Hiya,

    In the README you say that Rectify is implemented as an option, but the default is True. I would update the README to reflect that.

    You also make it sound like weight_decouple is an available option in the TF version. But it isn't:

    | AdaBeliefOptimizer(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-14, weight_decay=0.0, rectify=True, amsgrad=False, sma_threshold=5.0, total_steps=0, warmup_proportion=0.1, min_lr=0.0, name='AdaBeliefOptimizer', print_change_log=True, **kwargs)

    I just get an error message when I try to set weight_decouple=True.

    Great work otherwise!

    opened by grofte 2