MADGRAD Optimization Algorithm for TensorFlow
This package implements the MADGRAD algorithm proposed in "Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization" (Aaron Defazio and Samy Jelassi, 2021).
About The Project
MADGRAD combines dual averaging of gradients with momentum and a per-coordinate adaptive step size. In the paper's experiments, it matches or outperforms Adam and SGD with momentum. This project offers a TensorFlow implementation of the algorithm, along with usage examples and tests.
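To make the description concrete, here is a minimal NumPy sketch of the update rule as written in Algorithm 1 of the paper. It is not this package's TensorFlow code. The function name `madgrad_sketch` and its arguments are hypothetical, weight decay is left out, and the momentum parameter is assumed to map to the paper's interpolation weight as c = 1 - momentum.

```python
import numpy as np


def madgrad_sketch(grad_fn, x0, lr=0.01, momentum=0.9, eps=1e-6, steps=100):
    """Hypothetical NumPy sketch of MADGRAD (Algorithm 1, Defazio & Jelassi 2021).

    grad_fn(x) must return the (stochastic) gradient at x as a NumPy array.
    Weight decay is omitted; c = 1 - momentum is an assumed mapping.
    """
    x0 = np.asarray(x0, dtype=np.float64)
    x = x0.copy()              # current iterate x_k
    s = np.zeros_like(x0)      # weighted gradient sum s_k (the dual average)
    nu = np.zeros_like(x0)     # weighted sum of squared gradients nu_k
    c = 1.0 - momentum         # interpolation weight c_k, held constant here
    for k in range(steps):
        g = grad_fn(x)
        lam = lr * np.sqrt(k + 1)            # lambda_k = gamma * sqrt(k + 1)
        s += lam * g
        nu += lam * g * g
        z = x0 - s / (np.cbrt(nu) + eps)     # dual-averaged point, anchored at x0
        x = (1.0 - c) * x + c * z            # momentum as averaging toward z
    return x
```

Note that z is anchored at the starting point x0 rather than at the previous iterate (the dual averaging part), and the per-coordinate scale uses the cube root of the accumulated squared gradients instead of Adam's square root. For example, `madgrad_sketch(lambda x: x, [1.0], lr=0.1, steps=200)` runs the sketch on f(x) = x²/2.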
Prerequisites
The prerequisites can be installed separately from the requirements.txt file:
pip install -r requirements.txt
Installation
This project is built with Python 3 and can be installed directly with pip:
pip install tf-madgrad
Usage
To use the optimizer in a tf.keras model, import and instantiate the MadGrad optimizer provided by the tf-madgrad package:
import tensorflow as tf
from madgrad import MadGrad

# Create the architecture (shape, classes, activation and loss are placeholders for your own values)
inp = tf.keras.layers.Input(shape=shape)
x = ...  # hidden layers applied to `inp`
op = tf.keras.layers.Dense(classes, activation=activation)(x)
# Instantiate the model
model = tf.keras.models.Model(inp, op)
# Pass the MadGrad optimizer to the compile function
model.compile(optimizer=MadGrad(lr=0.01), loss=loss)
# Fit the Keras model as normal
model.fit(...)
This implementation also supports distributed training with tf.distribute.Strategy.
See an MNIST example here.
Contributing
Any and all contributions are welcome. Please raise an issue if the optimizer gives incorrect results or crashes unexpectedly during training.
License
Distributed under the MIT License. See LICENSE for more information.
Contact
Feel free to reach out with any issues or requests related to this implementation.
Darshan Deshpande - Email | LinkedIn
Citations
@misc{defazio2021adaptivity,
title={Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization},
author={Aaron Defazio and Samy Jelassi},
year={2021},
eprint={2101.11075},
archivePrefix={arXiv},
primaryClass={cs.LG}
}