Overview

Elegy

Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax in the high-level API, and it is 100% framework-agnostic in the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not yet support Windows natively.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface you can use by following these steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
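
The snippet above assumes X_train, y_train, X_test, and y_test already exist. A minimal sketch that creates random placeholder data with MNIST-like shapes (purely illustrative; in practice you would load a real dataset):

import numpy as np

# Placeholder data, flattened to (N, 784) to match the MLP's Dense layers.
X_train = np.random.uniform(size=(60000, 784)).astype(np.float32)
y_train = np.random.randint(0, 10, size=(60000,))
X_test = np.random.uniform(size=(10000, 784)).astype(np.float32)
y_test = np.random.randint(0, 10, size=(10000,))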

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure Jax:

1. Calculate our loss, logs, and states:

import elegy
import jax
import jax.numpy as jnp
import numpy as np
import optax

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x, # inputs
        y_true, # labels
        states: elegy.States, # model state
        initializing: bool, # if True we should initialize our parameters
    ):  
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        # metrics
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
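
To use this with an actual Flax module, you could instantiate the model along these lines (a sketch; it assumes the module argument is stored as self.module, as in the high-level API):

import flax.linen as nn

class MLPWithDropout(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        # deterministic=False makes Dropout use the "dropout" rng passed above
        x = nn.Dropout(0.2, deterministic=False)(x)
        x = jax.nn.relu(x)
        return nn.Dense(10)(x)

model = LinearClassifier(
    module=MLPWithDropout(),
    optimizer=optax.rmsprop(1e-3),
)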

More Info

Examples

To run the examples first install some required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py 

Contributing

Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.5.0},
year = {2020},
}

Here the version may be retrieved either from the Release tag or from the file elegy/__init__.py, and the year corresponds to the project's release year.

Comments
  • Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks.

    As noted below, this PR contains the following features:

    • It turns Elegy into a framework agnostic library by removing the dependencies between elegy.Model and elegy.Module, it proposes the GeneralizedModule API and implements it for Flax, Haiku, Elegy Module types, and regular python functions.
    • It introduces a new low-level API similar to Pytorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
    • General changes that enable the framework-agnostic mindset.
    • Many quality of life changes like standardization of hooks, simplification of the Module system, etc.

    Tasks:

    • [x] Create hooks module
    • [x] Refactor Model with low-level API and remove Module dependencies
    • [x] Refactor Module to use hooks
    • [x] Create GeneralizedModule and GeneralizedOptimizer Interfaces
    • [x] Implement GeneralizedModule for flax.linen.Module
    • [x] Implement GeneralizedModule for elegy.Module
    • [x] Implement GeneralizedModule for haiku.Module
    • [x] Implement GeneralizedOptimizer for optax.GradientTransformation
    • [x] Implement GeneralizedOptimizer for elegy.Optimizer
    • [x] Fix Model.summary
    • [x] Fix tests
    • [x] Fix examples
    • [ ] Fix README
    • [ ] Fix guides
    • [ ] Fix docstrings
    opened by cgarciae 27
  • WGAN-GP low-level API example

    A more extensive example using the new low-level API: Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

    Some good generated images (attached): epoch-0079, epoch-0084, epoch-0089.

    Some notes:

    • I first tried to train a DCGAN, which uses binary crossentropy, but I ran into balancing issues: the discriminator quickly becomes too good, so the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore; most use BCE.
    • I'm still in favor of making Module.apply() return init(). It's just too much boilerplate to use an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation which I think is also not nice.
    • Can we make Module.apply() accept params and states separately instead of collections? It's annoying having to construct a dict {'params': params, 'states': states} every time.
    • It would be nice if elegy.States were a dict so that the user can decide for themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
    • Model.save() fails on this model, partially due to the extra jitted functions; but even when I remove them, cloudpickle chokes on _HooksContext.

    @cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

    opened by alexander-g 11
  • Add learning rate logging

    Implements the same functionality from #131 using only minor modifications to elegy.Optimizer.

    • [x] Add lr_schedule and steps_per_epoch to Optimizer.
    • [x] Implement Optimizer.get_effective_learning_rate
    • [x] Copy logging code from #131
    • [x] Add documentation

    @alexander-g Here is a proposal that is a bit simpler, closer to what I mentioned in #124. What do you think? @charlielito should we log the learning rate automatically if available or should we create a Callback?

    opened by cgarciae 9
  • Question: how to set the random state when calling model.predict(...)

    Not sure if this is the right place to post this...

    I have built and trained a VAE. When calling model.predict(x=test_set), I would like to make multiple predictions for each item in the test set (because VAEs are probabilistic). That way I can look at the distribution of predictions for each item in the test_set.

    The call() for the VAE includes the line
    intrinsic_latents = mean + stds * jax.random.normal(self.next_key(), mean.shape).

    I haven't been able to find an explanation of how self.next_key() works or how to change the random seed on each call so that I can get different predictions. I could rewrite the code so that random seeds are explicitly passed, but I assume there is some functionality built into elegy to make this easy?

    Could someone explain how this works, or point me to the documentation explaining it?

    Thanks!

    opened by jfcrenshaw 8
  • Examples Cleanup

    • refactored examples/imagenet/resnet_imagenet.py to accept parameters instead of modifying them inside the script
    • added README.md for examples/imagenet/
    • removed unnecessary Lambda class from examples/mnist.py
    • moved global average pooling in examples/mnist_conv.py before the Linear layer
    opened by alexander-g 7
  • Resnet

    • ResNet model architecture and an example for training on ImageNet
      • code is mostly adapted from the flax library
      • pretrained ResNet50 with 76.5% accuracy
      • pretrained ResNet18 with 68.7% accuracy
    • Experimental support for mixed precision: previously all layers set their parameters' dtype to the input's dtype. This is incorrect: for numerical stability reasons, all parameters should be float32 even when performing float16 computations. See more here (a short sketch of this pattern follows after this list).
    • Some issues I had during training:
      • There seems to be a memory leak during training; RAM constantly increased.
      • I had to use smaller batch sizes than when training with flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on an RTX 2080 Ti). This might of course be due to a mistake in my code, but the number of parameters is identical to the flax and PyTorch versions, so I think it might be somewhere else.
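
    A minimal JAX sketch of the float32-master-weights pattern from the mixed-precision bullet above (illustrative only, not the actual implementation in this PR):

    import jax.numpy as jnp

    # Master parameters stay float32 for numerical stability.
    w = jnp.zeros((784, 10), jnp.float32)

    def forward_fp16(x, w):
        # Cast the inputs and a temporary copy of the weights to float16
        # for the computation; the float32 master copy is left untouched.
        return jnp.dot(x.astype(jnp.float16), w.astype(jnp.float16))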
    opened by alexander-g 7
  • [Bug] Problem with computing metrics

    Describe the bug: Hi, when I use the fit function I get an error message saying that the update function is not provided with y_true and y_pred. It seems to come from the model's metrics, because if I comment out the metrics line there is no error:

    TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'
    

    Minimal code to reproduce:

    import jax
    import jax.numpy as jnp
    import ml_collections
    import numpy as np
    import optax
    import elegy as eg
    
    
    class eCNN(eg.Module):
        """A simple CNN model."""
    
        @eg.compact
        def __call__(self, x):
            x=eg.Conv(10,kernel_size=(10,))(x)
            x=jax.nn.relu(x)
            x = eg.Linear(1)(x)
            x=jax.nn.sigmoid(x)
            return x
    
    n=200
    X_train = np.random.rand(n*100).reshape(n,100)
    y_train = np.random.rand(n).reshape(n,1)
    print(X_train.shape)
    print(y_train.shape)
    
    model = eg.Model(
        module=eCNN(),
        loss=[
            eg.losses.MeanSquaredError(),
        ],
        metrics=eg.metrics.MeanSquareError(),  #Line to be commented to get rid of the error
        optimizer=optax.rmsprop(1e-3),
    )
    
    model.fit(X_train,y_train,
        epochs=10,
        batch_size=20,
        #validation_data=0.1,
        shuffle=False,
        callbacks=[eg.callbacks.TensorBoard("summaries")]
        )
    

    Library Info:

    import elegy
    print(elegy.__version__) 
    # 0.8.4
    
    bug 
    opened by organic-chemistry 6
  • Multi-gpu with pmap docs

    One of the selling points of jax is the pmap transformation, but best practices for actually making your training loop parallelizable are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a pytorch-lightning-like API where this is a single arg to model.fit?

    opened by sooheon 6
  • SCCE fix for bug in Jax<0.2.7

    Small fix for a bug in Jax<0.2.7 where jax.lax.take_along_axis gives incorrect results for uint8 indices. Very relevant for semantic segmentation.

    Alternatively, consider updating Jax.

    opened by alexander-g 6
  • Dataset & DataLoader

    Dataset and parallel DataLoader API similar to PyTorch. Can be used with Model.fit():

    class MyDataset(elegy.data.Dataset):
        def __len__(self):
            return 128
    
        def __getitem__(self, i):
            #dummy data
            return np.random.random([224, 224, 3]),  np.random.randint(10)
    
    ds     = MyDataset()
    loader = elegy.data.DataLoader(ds, batch_size=8, n_workers=8, worker_type='thread', shuffle=True)
    
    batch = next(iter(loader))
    assert batch[0].shape == (8,224,224,3)
    assert batch[1].shape == (8,)
    assert len(loader) == 16
    
    model.fit(loader, epochs=10)
    
    opened by alexander-g 6
  • Implemented BinaryCrossentropy metric

    Updates:

    • Created BinaryCrossentropy metric
    • Created basic tests for BinaryCrossentropy metric (passing)
    • Created docs for BinaryCrossentropy metric
    • Refactored main docs by balancing files and correcting language typos
    documentation 
    opened by sebasarango1180 6
  • use poetry-core

    poetry-core is intended to be a lightweight, fully compliant, self-contained package allowing PEP 517 compatible build frontends to build Poetry managed projects.

    Using poetry-core allows distribution packages to depend only on the build backend.
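
    For reference, a typical pyproject.toml build-system table using poetry-core looks like this (a sketch; the exact entry in this PR may differ):

    [build-system]
    requires = ["poetry-core"]
    build-backend = "poetry.core.masonry.api"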

    opened by dotlambda 0
  • [Bug] Documentation/API reference not accessible via project website

    Hi, it looks like I can't really access the API reference for Elegy. The corresponding link on the project's website simply takes me back to the main page (https://poets-ai.github.io/elegy/). Any idea what's up?

    bug 
    opened by geomlyd 0
  • [Bug] elegy does not work with latest haiku version

    Describe the bug: When I type 'import elegy' I get this error:

     File "/home/kpmurphy/mambaforge/lib/python3.10/site-packages/elegy/generalized_module/haiku_module.py", line 4, in <module>
        from haiku._src.base import current_bundle_name
    

    Minimal code to reproduce

    import elegy
    

    Library Info:

    >>> jax.__version__
    '0.2.28'
    >>> haiku.__version__
    '0.0.9.dev'
    >>> elegy.__version__. #  elegy-0.5.0-py3-none-any.whl 
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'elegy' is not defined
    >>> 
    

    bug 
    opened by murphyk 5
  • CSVLogger iteration over a 0-d array

    Describe the bug: When using the CSVLogger callback, elegy crashes at the end of the first epoch.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    x = np.random.randn(64, 1)
    y = np.random.randn(64, 1)
    
    model = eg.Model(
        eg.nn.Linear(1),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        x,
        y,
        epochs=10,
        callbacks=[
            eg.callbacks.CSVLogger("train.csv"),  # <-- commenting out this line avoids the crash
        ]
    )
    

    Stack trace:

    Epoch 1/10
    2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
        return '"[%s]"' % (", ".join(map(str, k)))
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
        raise TypeError("iteration over a 0-d array")  # same as numpy error
    TypeError: iteration over a 0-d array
    

    Expected behavior: Should treat a 0-d array as a scalar.

    Library Info: python version 3.8.13, elegy version 0.8.6, treex version 0.6.10.

    Additional context: More detailed error information shows that the error occurs because the array is a jax DeviceArray, while the test for a zero-dimensional array uses the line

    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:6 │
    │ 8 in handle_value                                                                                │
    │                                                                                                  │
    │    65 │   │   │   if isinstance(k, six.string_types):                                            │
    │    66 │   │   │   │   return k                                                                   │
    │    67 │   │   │   elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:                   │
    │ ❱  68 │   │   │   │   return '"[%s]"' % (", ".join(map(str, k)))                                 │
    │    69 │   │   │   else:                                                                          │
    │    70 │   │   │   │   return k                                                                   │
    │    71                                                                                            │
    │                                                                                                  │
    │ ╭──────────────────────────── locals ─────────────────────────────╮                              │
    │ │ is_zero_dim_ndarray = False                                     │                              │
    │ │                   k = DeviceArray(4.8264385e-05, dtype=float32) │                              │
    │ ╰─────────────────────────────────────────────────────────────────╯                              │
    │                                                                                                  │
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in │
    │ __iter__                                                                                         │
    │                                                                                                  │
    │   242                                                                                            │
    │   243   def __iter__(self):                                                                      │
    │   244 │   if self.ndim == 0:                                                                     │
    │ ❱ 245 │     raise TypeError("iteration over a 0-d array")  # same as numpy error                 │
    │   246 │   else:                                                                                  │
    │   247 │     return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())            │
    │   248                                                                                            │
    │                                                                                                  │
    │ ╭───────────────────── locals ─────────────────────╮                                             │
    │ │ self = DeviceArray(4.8264385e-05, dtype=float32) │                                             │
    │ ╰──────────────────────────────────────────────────╯                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
    TypeError: iteration over a 0-d array
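
    A possible fix, sketched here as an untested suggestion: make the zero-dimension check duck-typed so it also catches jax DeviceArrays instead of only np.ndarray:

    from collections.abc import Iterable

    def handle_value(k):
        # Hypothetical patch: anything with ndim == 0 (NumPy array or
        # jax DeviceArray) is treated as a scalar, not as an iterable.
        is_zero_dim_array = hasattr(k, "ndim") and k.ndim == 0
        if isinstance(k, str):
            return k
        elif isinstance(k, Iterable) and not is_zero_dim_array:
            return '"[%s]"' % ", ".join(map(str, k))
        else:
            return k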
    
    bug 
    opened by ScottAlexanderCameron 0
  • Metrics ignore "on" keyword arg

    Describe the bug: I have an application where I need to output multiple values from a network, which I am doing with a dictionary and the on keyword argument. This works fine for the loss functions but not for the metrics.

    Minimal code to reproduce:

    import elegy as eg
    import optax
    import numpy as np
    
    
    def data_generator():
        while True:
            yield (
                np.random.randn(10, 1),
                {"target": {"y": np.random.randn(10, 1)}},
            )
    
    
    class MyModule(eg.Module):
        @eg.compact
        def __call__(self, x):
            return {"y": eg.nn.Linear(1)(x)}
    
    
    model = eg.Model(
        MyModule(),
        loss=eg.losses.MeanSquaredError(on="y"),
        metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        data_generator(),
        steps_per_epoch=10,
        epochs=10,
    )
    

    Stack trace:

    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
        tmp_logs = self.train_on_batch(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
        logs, model = train_step_fn(self, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
        return model.train_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
        grads, (logs, model) = grad_fn(params, model, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
        loss, logs, model = model.test_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
        batch_loss_and_logs.update(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
        self.metrics.update(**metrics_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
        metric.update(**metric_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
        values = _mean_absolute_error(preds, target)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
        target = target.astype(preds.dtype)
    AttributeError: 'dict' object has no attribute 'astype'
    

    Expected behavior: Should produce the same result as if the dictionaries were removed and the on arg were not specified.

    Library Info: python version 3.8.13, elegy version 0.8.6, treex version 0.6.10.

    Additional context: From my digging, the cause seems to be that the Metric.update() method is called instead of the __call__ method.

    bug 
    opened by ScottAlexanderCameron 0
  • [Bug] Elegy crash on GPU

    Describe the bug

    Running mnist_cnn.py in the example dir crashes the instance at the end of the first epoch.

    This was previously reported on a Colab GPU instance, but I can reproduce it on the CLI too.

    Running on CPU does not have this problem.

    Running in eager mode with GPU does not have this problem.

    Minimal code to reproduce

    python mnist_cnn.py
    

    Expected behavior: Not stuck.

    Library Info: CentOS Linux release 7.6.1810, elegy 0.8.6.

    Additional context: absl-py==1.2.0 aiohttp==3.8.1 aiosignal==1.2.0 async-timeout==4.0.2 attrs==22.1.0 certifi==2021.10.8 charset-normalizer==2.1.1 chex==0.1.4 click==8.1.3 cloudpickle==1.6.0 colorama==0.4.5 commonmark==0.9.1 cycler==0.11.0 datasets==2.4.0 dill==0.3.5.1 dm-tree==0.1.7 docker-pycreds==0.4.0 einops==0.4.1 elegy==0.8.6 etils==0.7.1 filelock==3.8.0 flax==0.4.2 fonttools==4.36.0 frozenlist==1.3.1 fsspec==2022.7.1 gitdb==4.0.9 GitPython==3.1.27 h5py==3.6.0 huggingface-hub==0.8.1 idna==3.3 importlib-resources==5.9.0 jax==0.3.16 jaxlib==0.3.15+cuda11.cudnn82 kiwisolver==1.4.4 matplotlib==3.5.3 msgpack==1.0.4 multidict==6.0.2 multiprocess==0.70.13 numpy==1.22.3 opt-einsum==3.3.0 optax==0.1.3 packaging==21.3 pandas==1.4.3 pathtools==0.1.2 Pillow==9.2.0 promise==2.3 protobuf==3.20.1 psutil==5.9.1 pyarrow==9.0.0 Pygments==2.13.0 pyparsing==3.0.9 python-dateutil==2.8.2 pytz==2022.2.1 PyYAML==6.0 requests==2.28.1 responses==0.18.0 rich==11.2.0 scipy==1.8.0 sentry-sdk==1.9.5 setproctitle==1.3.2 shortuuid==1.0.9 six==1.16.0 smmap==5.0.0 tensorboardX==2.5.1 toolz==0.12.0 tqdm==4.64.0 treeo==0.0.10 treex==0.6.10 typing_extensions==4.3.0 urllib3==1.26.11 wandb==0.12.21 xxhash==3.0.0 yarl==1.8.1 zipp==3.8.1

    bug 
    opened by jiyuuchc 2
Releases(0.8.6)
  • 0.8.6(Mar 23, 2022)

    🚀 Features

    • Weights and Biases Callback for Elegy
      • PR: #220

    🐛 Fixes

    • Docs typos
      • PR: #222
    • Donate model's memory buffer to jit/pmap functions.
      • PR: #226
    • Lazy load wandb
      • PR: #228
  • 0.8.5(Feb 23, 2022)

  • 0.8.4(Dec 14, 2021)

  • 0.8.3(Dec 13, 2021)

  • 0.8.2(Dec 13, 2021)

  • 0.8.1(Nov 8, 2021)

    Elegy is now based on Treex 🎉

    Changes

    • Removes the module, nn, metrics, and losses modules from Elegy; instead, Elegy re-exports them from Treex.
    • GeneralizedModule and friends are gone; to use Flax Modules, use the elegy.nn.FlaxModule wrapper.
    • Low level API is massively simplified:
      • States is removed; since Model is a pytree, all parameters are tracked automatically thanks to Treex / Treeo.
      • All static state arguments (training, initializing) are removed; Modules can simply use self.training to pick their training state and self.initializing() to check whether they are initializing.
      • Signature for pred_step, test_step, and train_step now simply consists of inputs and labels, where labels is a dict that can contain additional keys like sample_weight or class_weight as required by the losses and metrics.
    • Adds the DistributedStrategy class, which currently has 3 instances:
      • Eager: Runs model in a single device in eager mode (no jit)
      • JIT: Runs model in a single device with jit
      • DataParallel: Run the model in multiple devices using pmap.
    • Adds methods to change the model's distributed strategy:
      • .distributed(strategy = DataParallel): changes the distributed strategy, DataParallel used by default.
      • .local(): changes the distributed strategy to JIT.
      • .eager(): changes the distributed strategy to Eager.
    • Removes the .eager field in favor of the .eager() method.
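
    A usage sketch of the strategy-switching methods listed above (assuming, as is common for pytree-based Models, that each method returns an updated model):

    model = model.distributed()  # DataParallel across local devices (pmap)
    model = model.local()        # back to single-device jit
    model = model.eager()        # single device, no jit (handy for debugging)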
  • 0.7.4(Jun 1, 2021)

  • 0.7.2(Mar 10, 2021)

  • 0.7.1(Mar 1, 2021)

  • 0.7.0(Feb 22, 2021)

    Features

    • init is now called only once internally and is required to be called explicitly by the user only under certain circumstances.
    • summary now uses jax.eval_shape under the hood, so it is very fast since it doesn't allocate memory or perform any computations on the device.

    Merged pull requests:

    • Fix notebook #166 (cgarciae)
    • Single Initialization: Removes the current progressive initialization in favor of a single underlying call to init_step. #165 (cgarciae)
  • 0.6.0(Feb 14, 2021)

  • 0.5.0(Feb 8, 2021)

    This version simplifies parts of the low-level API in the spirit of what was introduced in 0.4.0, providing a more homogeneous and simpler experience.

    Merged pull requests:

    • Improve States: uses __dict__ so States works with vars #159 (cgarciae)
    • Simplify API: Cleans-up some API details around Model and Module #158 (cgarciae)
  • 0.4.1(Feb 3, 2021)

  • 0.4.0(Feb 1, 2021)

    Implemented enhancements:

    • [Feature Request] Monitoring learning rates #124

    Merged pull requests:

  • 0.3.0(Dec 17, 2020)

    Implemented enhancements:

    • elegy.nn.Sequential docs not clear #107
    • [Feature Request] Community example repo. #98

    Fixed bugs:

    • [Bug] Accuracy from Model.evaluate() is inconsistent with manually computed accuracy #109
    • Exceptions in "Getting Started" colab notebook #104

    Closed issues:

    • l2_normalize #102
    • Need some help for contributing new losses. #93
    • Document Sum #62
    • Binary Accuracy Metric #58
    • Automate generation of API Reference folder structure #19
    • Implement Model.summary #3

    Merged pull requests:

  • 0.2.2(Aug 31, 2020)

  • 0.2.1(Aug 25, 2020)

  • 0.2.0(Aug 17, 2020)

  • 0.1.5(Jul 28, 2020)

    • Mean Absolute Percentage Error Implementation @Ciroye
    • Adds elegy.nn.Linear, elegy.nn.Conv2D, elegy.nn.Flatten, elegy.nn.Sequential @cgarciae
    • Add Elegy hooks @cgarciae
    • Improves Tensorboard support @Davidnet
    • Added coverage metrics to CI @charlielito
  • 0.1.4(Jul 24, 2020)

    • Adds elegy.metrics.BinaryCrossentropy @sebasarango1180
    • Adds elegy.nn.Dropout and elegy.nn.BatchNormalization @cgarciae
    • Improves documentation
    • Fixes a bug that caused an error when using is_training via dependency injection in Model.predict.
  • 0.1.3(Jul 23, 2020)
