Elegy
Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.
Main Features
- Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
- Flexible: Elegy provides a functional PyTorch Lightning-like low-level API that offers maximal flexibility when needed.
- Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
- Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and Numpy pytrees; see the sketch below.
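As a quick illustration of the last point, a plain Python generator yielding (inputs, labels) Numpy batches can be passed straight to fit, much as in Keras. A minimal sketch, where model, X_train, and y_train are placeholders for the objects built in the quick starts below:

```python
import numpy as np

def batches(X, y, batch_size=64):
    # yield (inputs, labels) Numpy batches indefinitely
    while True:
        idx = np.random.randint(0, len(X), size=batch_size)
        yield X[idx], y[idx]

model.fit(x=batches(X_train, y_train), epochs=10, steps_per_epoch=200)
```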
For more information take a look at the Documentation.
Installation
Install Elegy using pip:
```bash
pip install elegy
```
For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not support Windows natively yet.
Quick Start: High-level API
Elegy's high-level API provides a very simple interface you can use by implementing the following steps:
1. Define the architecture inside a Module. We will use Flax Linen for this example:
```python
import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```
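To sanity-check the module on its own before handing it to Elegy, you can initialize and apply it directly with Flax. A minimal sketch; the PRNG key and the 784-feature input (flattened 28x28 images) are arbitrary choices:

```python
import jax.numpy as jnp

mlp = MLP()
# initialize the module's parameters from a dummy batch
variables = mlp.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
# run a forward pass
logits = mlp.apply(variables, jnp.ones((1, 784)))
print(logits.shape)  # (1, 10)
```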
2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
```python
import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
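The fit call below assumes X_train, y_train, X_test, and y_test already exist. For a fully self-contained run you can substitute random stand-in data shaped like flattened MNIST (purely illustrative; in practice you would load a real dataset):

```python
import numpy as np

# random stand-in data: 784 features, 10 classes
X_train = np.random.uniform(size=(6400, 784)).astype(np.float32)
y_train = np.random.randint(0, 10, size=(6400,))
X_test = np.random.uniform(size=(1000, 784)).astype(np.float32)
y_test = np.random.randint(0, 10, size=(1000,))
```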
3. Train the model using the fit method:
```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
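Since the high-level API mirrors Keras, evaluation and inference should follow the familiar pattern. A sketch assuming the Keras-style evaluate and predict methods:

```python
# compute loss and metrics on held-out data
logs = model.evaluate(x=X_test, y=y_test, batch_size=64)
print(logs)

# raw model outputs (logits here) for a few samples
y_pred = model.predict(x=X_test[:8])
print(y_pred.shape)  # (8, 10)
```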
Quick Start: Low-level API
Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure Jax:
1. Calculate our loss, logs, and states:
```python
import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))
```
2. Instantiate our LinearClassifier with an optimizer:
```python
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
```
3. Train the model using the fit method:
```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
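Because Elegy derives evaluation from the same step methods, implementing test_step is enough for evaluate to work too. A sketch, reusing the hypothetical held-out data from the high-level quick start:

```python
# evaluate runs our custom test_step over the data
logs = model.evaluate(x=X_test, y=y_test, batch_size=64)
print(logs)  # e.g. {'accuracy': ..., 'loss': ...}
```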
Using Jax Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API:
```python
class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            # initialize the Flax module and get its initial outputs
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            # reassemble the variables dict and apply the module
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # split the trainable params from the rest of the state
        net_states, net_params = variables.pop("params")
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
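Note that this test_step reads the network from self.module, which the snippet never sets. One plausible way to wire it up is to pass a Flax module to the constructor, mirroring the module argument of the high-level API shown earlier (an assumption for illustration, not something the snippet spells out):

```python
# passing module= populates self.module, as in elegy.Model above;
# MLP is the Flax Linen module from the high-level quick start
model = LinearClassifier(
    module=MLP(),
    optimizer=optax.rmsprop(1e-3),
)
```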
More Info
- Getting Started: High-level API tutorial.
- Getting Started: Low-level API tutorial.
- Elegy's Documentation.
- The examples directory.
- What is Jax?
Examples
To run the examples, first install some required packages:

```bash
pip install -r examples/requirements.txt
```

Now run the example:

```bash
python examples/flax_mnist_vae.py
```
Contributing
Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.
About Us
We are some friends passionate about ML.
License
Apache
Citing Elegy
To cite this project:
BibTeX
```bibtex
@software{elegy2020repository,
    author = {PoetsAI},
    title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
    url = {https://github.com/poets-ai/elegy},
    version = {0.5.0},
    year = {2020},
}
```
The current version may be retrieved either from the Release tag or the file elegy/__init__.py, and the year corresponds to the project's release year.