Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

Equinox

Equinox brings more power to your model building in JAX.
Represent parameterised functions as data, and use filtered transformations for powerful fine-grained control of the model-building process.

Equinox is half tech-demo, half neural network library.

Equinox in brief

Building neural networks

Build models using a PyTorch-like class-based API without sacrificing JAX-like functional programming.

In particular, there's none of the extra complexity that many libraries have, such as class-to-functional transformations, custom notions of parameter groups, or specially wrapped library.jits and library.grads.

Equinox is a tiny library -- no behind-the-scenes magic, guaranteed. The elegance of Equinox is its selling point in a world that already has Haiku, Flax etc.

Technical contributions

Equinox represents parameterised functions as data. That is, you can represent your whole model (parameters, buffers, forward pass, etc.) as a PyTree. Parameterised functions can be passed in and out of higher-order functions -- like passing models to jax.vmap, vmap'd functions to loss functions, or loss functions to JIT and grad.

Equinox additionally offers thin wrappers around jax.jit/jax.grad that understand the PyTree structure of their inputs: you can JIT/differentiate a single leaf, not just a whole argument. (We don't offer this for jax.vmap because, interestingly, jax.vmap already offers this: its in_axes argument can be a PyTree matching the structure of the inputs.)
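For instance, jax.vmap's in_axes can itself be a PyTree mirroring the inputs, with None on the leaves that should not be batched. A minimal plain-JAX sketch (the function affine and the array shapes are illustrative only):

```python
import jax
import jax.numpy as jnp

def affine(params, x):
    # params is a PyTree (a dict); x is a single example of shape (2,).
    return params["weight"] @ x + params["bias"]

weight = jnp.ones((3, 2))                 # shared across the batch
biases = jnp.arange(12.0).reshape(4, 3)   # one bias vector per example
xs = jnp.ones((4, 2))                     # batch of 4 inputs

# in_axes mirrors the argument structure: None means "don't map this leaf",
# 0 means "map over axis 0". Here only the bias leaf and x are batched.
out = jax.vmap(affine, in_axes=({"weight": None, "bias": 0}, 0))(
    {"weight": weight, "bias": biases}, xs
)
# out.shape == (4, 3)
```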

There are some similarities to existing libraries (like the structs of flax.linen or the functors of Flux.jl), but to the best of my knowledge Equinox offers something genuinely new to the JAX framework.

Installation

pip install git+https://github.com/patrick-kidger/equinox.git

Requires Python 3.7+ and JAX 0.2.18+.

Quick example

import equinox as eqx
import functools as ft, jax, jax.numpy as jnp, jax.random as jrandom

# Define our model. `Module` subclasses are both functions and data, so we can pass them into higher-order
# functions like vmap/jit/grad, or our loss function later.
# There's no magic in `Module`. Pretty much all it does is register your class as a PyTree node.
class LinearOrIdentity(eqx.Module):
    weight: jnp.ndarray
    flag: bool

    def __init__(self, in_features, out_features, flag, key):
        self.weight = jrandom.normal(key, (out_features, in_features))
        self.flag = flag

    def __call__(self, x):
        if self.flag:
            return x
        return self.weight @ x

# We use the fact that our model is data, by passing it in as an argument to the loss.
# There's no magic here: `model` is a PyTree like any other.
#
# We use filtered transformations to unpack its data and select just the leaves we want to 
# JIT+differentiate. (In this case, all floating-point JAX arrays -- `weight` but not `flag`.)
# There's no magic here: filtered transformations act on any kind of PyTree.
#
# Equinox is JAX-friendly. If you want to differentiate everything, just use `jax.jit` and `jax.grad`.
@ft.partial(eqx.jitf, filter_fn=eqx.is_inexact_array)
@ft.partial(eqx.gradf, filter_fn=eqx.is_inexact_array)
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

modelkey, xkey, ykey = jrandom.split(jrandom.PRNGKey(0), 3)
model = LinearOrIdentity(2, 3, flag=False, key=modelkey)
x, y = jrandom.normal(xkey, (100, 2)), jrandom.normal(ykey, (100, 3))
grads = loss(model, x, y)

This quick example exposes you to the two main concepts in Equinox: callable PyTrees and filtered transformations. Together, they're very powerful.

Callable PyTrees

These are just methods attached to a PyTree. (In this case, the __call__ method of a Module subclass.) All that subclassing Module really does is automatically register your class with JAX as a custom PyTree node; there's no magic here.

The PyTree structure holds the data (parameters, buffers, submodules, boolean flags, even arbitrary Python objects). The methods on the class define operations parameterised by that data -- in this case and in particular, the forward pass through a model.

This gives a way to represent parameterised functions as data: and as such, they're suitable for passing in and out of JAX functions. This is what we do when passing the model instance to the loss function.

Footnote: callable PyTrees actually aren't anything special; the built-in Python methods on lists and dictionaries are another example of callable PyTrees.
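To make the idea concrete without Equinox, here is a minimal sketch that registers an ordinary class as a PyTree node via jax.tree_util.register_pytree_node_class and gives it a __call__ forward pass. It illustrates the concept under that assumption and is not Equinox's implementation of Module; the class name Linear is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Linear:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    # The data: weight and bias are the PyTree children; no static aux data.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    # The parameterised function: a forward pass using the data above.
    def __call__(self, x):
        return self.weight @ x + self.bias

def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)

model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x, y = jnp.ones(3), jnp.zeros(2)
# The model goes straight into jax.jit/jax.grad like any other PyTree, and the
# gradient comes back as a Linear instance with the same structure.
grads = jax.jit(jax.grad(loss))(model, x, y)
```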

Filtered transformations

The one issue with putting everything about a model into a single PyTree is that it might not contain just trainable parameters. The above example includes a boolean flag, for example. We certainly can't differentiate this, and we may or may not wish to JIT-trace it rather than treat it as static.
In general we might have arbitrary Python objects, or perhaps JAX arrays that are buffers rather than trainable parameters.

Enter filtered transformations. These are equinox.jitf and equinox.gradf, which are very thin wrappers around jax.jit and jax.grad. Instead of specifying argnums to JIT/differentiate, we instead pass a filter that determines which PyTree leaves -- not just whole arguments -- to JIT/differentiate.

These aren't "a way to make JIT/grad work with model states" like many libraries have. They are general operations on PyTrees, and nothing about Module is special-cased.

  • For one thing, we don't need to special-case anything: Module is just a PyTree like any other.
  • For another, if you don't want to filter out anything at all, then don't: use jax.jit and jax.grad directly and they'll work just fine.

This gives a powerful, fine-grained way to control JIT and autodifferentiation.
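The core idea behind a filtered grad can be sketched in plain JAX: split the PyTree into the leaves selected by a filter and everything else, differentiate with respect to the first part only, and recombine inside the function. This is a hedged illustration, not Equinox's implementation; partition, combine and filtered_grad are hypothetical helper names, and is_inexact_array is a local stand-in for the filter described in the API section.

```python
import jax
import jax.numpy as jnp

def is_inexact_array(x):
    # Local stand-in for the filter of the same name described below.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)

def partition(tree, filter_fn):
    # Two PyTrees with the same structure as `tree`: the selected leaves
    # (None elsewhere), and everything else (None where a leaf was selected).
    dynamic = jax.tree_util.tree_map(lambda x: x if filter_fn(x) else None, tree)
    static = jax.tree_util.tree_map(lambda x: None if filter_fn(x) else x, tree)
    return dynamic, static

def combine(dynamic, static):
    # Treat None as a leaf so the two trees can be zipped back together.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d, dynamic, static, is_leaf=lambda x: x is None
    )

def filtered_grad(fun, filter_fn):
    # Differentiate `fun` w.r.t. the filtered leaves of its first argument only.
    def wrapper(tree, *args):
        dynamic, static = partition(tree, filter_fn)

        def inner(dyn):
            return fun(combine(dyn, static), *args)

        return jax.grad(inner)(dynamic)

    return wrapper

def loss(model, x):
    # `flag` is never traced, so ordinary Python control flow on it is fine.
    if model["flag"]:
        return jnp.sum(x)
    return jnp.sum(model["weight"] * x)

model = {"weight": jnp.array([2.0, 3.0]), "flag": False}
grads = filtered_grad(loss, is_inexact_array)(model, jnp.array([1.0, 4.0]))
# grads has the structure of model: a gradient for "weight", None for "flag".
```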

Integrates smoothly with JAX

There's nothing special about Equinox modules. They're just PyTrees.

There's nothing special about filtered transformations. They just operate on PyTrees.

Equinox is all just regular JAX -- PyTrees and transformations! Together, these two pieces allow us to specify complex models in JAX-friendly ways.

Examples

  • train_mlp.py gives a short example that introduces equinox.jitf and equinox.gradf. These will be used to select the parameters of an MLP and train them.

  • frozen_layer.py demonstrates how this approach really shines: some of the parameters will be trained, some of them will be frozen, but all of them will be efficiently JIT-traced.

  • build_model.py demonstrates how to build parameterised-functions-as-data using equinox.Module. In particular we'll construct an MLP from scratch, and then pass it into higher-order functions like JIT and grad in order to train it. This allows us to produce models using a familiar class-based syntax, that are also functional and integrate directly with JAX's JIT/autograd.

  • train_rnn.py trains an RNN on a toy clockwise/anticlockwise spiral classification problem. This demonstrates the use of jax.lax.scan with Equinox. (It just works, no tricks required.)

API

Full API list

# Filtered transformations
equinox.jitf
equinox.gradf
equinox.value_and_grad_f

# Filters
equinox.is_inexact_array
equinox.is_array_like

# Module
equinox.Module

# Utilities
equinox.apply_updates
equinox.tree_at
equinox.tree_equal

# Neural networks
equinox.nn.Linear
equinox.nn.Identity
equinox.nn.Dropout
equinox.nn.GRUCell
equinox.nn.LSTMCell
equinox.nn.Sequential
equinox.nn.MLP

Filtered transformations

equinox.jitf(fun, *, filter_fn=None, filter_tree=None, **kwargs)

Wraps jax.jit.

  • fun is a pure function to JIT compile.
  • filter_fn is a callable Any -> bool. It will be called on every leaf of every PyTree that is inputted to fun. If it returns True, the leaf will be traced. If it returns False, the leaf will be treated as static. Mutually exclusive with filter_tree.
  • filter_tree is a tree, or tuple of trees, of the same length as the number of inputs. (Or, if static_argnums is passed, the number of inputs not already marked static via static_argnums.) It must have the exact same tree structure as the inputs. Every leaf must be either True or False. Each leaf of filter_tree is matched up against the corresponding input: if it is True the leaf will be traced; if it is False the leaf will be treated as static. Mutually exclusive with filter_fn.
  • **kwargs are the usual other arguments to jax.jit, like static_argnums. In particular, a leaf will be marked static if either (a) it is filtered as being so, or (b) it is part of a PyTree that is marked through static_argnums.

Precisely one of filter_fn or filter_tree must be passed.
See also equinox.is_array_like, which is usually a good choice of filter_fn: this will trace everything that can possibly be traced, with everything else static.
See also equinox.tree_at for an easy way to create the filter_tree argument.
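A rough plain-JAX sketch of how trace-versus-static filtering can be built on top of jax.jit's static_argnums: flatten the inputs, pass the filtered leaves as traced arguments, and pass the remaining leaves plus the tree structure as static arguments. This is an assumption-laden illustration rather than Equinox's code; filtered_jit and is_array are hypothetical names, and the static leaves must be hashable.

```python
import functools
import jax
import jax.numpy as jnp

def is_array(x):
    return isinstance(x, jax.Array)

def filtered_jit(fun, filter_fn):
    # Leaves where filter_fn is True are traced; all other leaves are passed to
    # jax.jit as static arguments (so they must be hashable), together with the
    # tree structure needed to reassemble the original inputs.
    @functools.partial(jax.jit, static_argnums=(1, 2))
    def _jitted(dynamic_leaves, static_leaves, treedef):
        leaves = [s if d is None else d for d, s in zip(dynamic_leaves, static_leaves)]
        args = jax.tree_util.tree_unflatten(treedef, leaves)
        return fun(*args)

    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        dynamic = [x if filter_fn(x) else None for x in leaves]
        static = tuple(None if filter_fn(x) else x for x in leaves)
        return _jitted(dynamic, static, treedef)

    return wrapper

trace_log = []  # records the static flag each time the function is traced

@functools.partial(filtered_jit, filter_fn=is_array)
def f(params, flag):
    trace_log.append(flag)  # Python side effect: only runs while tracing
    # `flag` is static, so this is ordinary Python control flow.
    return params["w"] * 2.0 if flag else params["w"]

f({"w": jnp.ones(3)}, True)
f({"w": jnp.zeros(3)}, True)   # same shapes/dtypes and static values: cached
f({"w": jnp.ones(3)}, False)   # new static value: traced again
```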

equinox.gradf(fun, *, filter_fn=None, filter_tree=None, **kwargs)

Wraps jax.grad.

  • fun is a pure function to differentiate.
  • filter_fn is a callable Any -> bool. It will be called on every leaf of every PyTree that is marked as potentially requiring gradient via argnums. If it returns True, the leaf will be differentiated. If it returns False, the leaf will not be differentiated. Mutually exclusive with filter_tree.
  • filter_tree is a tree, or tuple of trees, of the same length as the number of inputs marked as potentially requiring gradient via argnums. It must have the exact same tree structure as the inputs. Every leaf must be either True or False. Each leaf of filter_tree is matched up against the corresponding input: if it is True the leaf will be differentiated; if it is False the leaf will not be differentiated. Mutually exclusive with filter_fn.
  • **kwargs are the usual other arguments to jax.grad, like argnums. In particular, a leaf will only be differentiated if (a) it is filtered as being so, and (b) it is part of a PyTree that is marked through argnums.

Precisely one of filter_fn or filter_tree must be passed.
See also equinox.is_inexact_array, which is usually a good choice of filter_fn: this will differentiate all floating-point arrays.
See also equinox.tree_at for an easy way to create the filter_tree argument.

Note that as the returned gradients must have the same structure as the inputs, all nondifferentiable components of the input PyTrees will have gradient None. As a result, a simple jax.tree_map(lambda m, g: m - lr * g, model, grad) will fail. Equinox therefore provides equinox.apply_updates as a simple convenience: it only applies an update where the gradient is not None. See below.

equinox.value_and_grad_f(fun, *, filter_fn=None, filter_tree=None, **kwargs)

Wraps jax.value_and_grad. Arguments are as equinox.gradf.

Filters

Any function Any -> bool can be used as a filter. We provide some convenient common choices.

equinox.is_inexact_array(element)

Returns True if element is a floating point JAX array (but not a NumPy array).

equinox.is_array_like(element)

Returns True if element can be interpreted as a JAX array (i.e. if jax.numpy.array does not throw an exception when called on it).
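Both filters are simple predicates. The following local re-implementations are a sketch for illustration only: the _sketch names are hypothetical, and counting complex dtypes as "inexact" is an assumption that follows NumPy's dtype hierarchy.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_inexact_array_sketch(element):
    # Floating-point (or complex) JAX arrays only; NumPy arrays are excluded
    # because they are not instances of jax.Array.
    return isinstance(element, jax.Array) and jnp.issubdtype(element.dtype, jnp.inexact)

def is_array_like_sketch(element):
    # "Can this be interpreted as a JAX array?", i.e. does jnp.array raise?
    try:
        jnp.array(element)
    except Exception:
        return False
    return True
```

For example, Python floats and NumPy arrays are array-like but are not inexact JAX arrays, while strings and functions are neither.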

Module

equinox.Module

Base class; create your model by inheriting from this.

Specify all its attributes at the class level (identical to dataclasses). This defines its children in the PyTree.

import typing
import equinox

class MyModule(equinox.Module):
    weight: typing.Any
    bias: typing.Any
    submodule: equinox.Module

In this case a default __init__ method is provided, which just fills in these attributes with the arguments passed: MyModule(weight, bias, submodule) or MyModule(weight=weight, bias=bias, submodule=submodule). Alternatively you can provide an __init__ method yourself. (For example to specify dimension sizes instead of raw weights.) By the end of __init__, every attribute must have been assigned.

from typing import Any
import jax
import equinox

class AnotherModule(equinox.Module):
    weight: Any

    def __init__(self, input_size, output_size, key):
        self.weight = jax.random.normal(key, (output_size, input_size))

After initialisation then attributes cannot be modified: models are immutable as per functional programming. (Parameter updates are made by creating a new model, not by mutating parameters in-place; see for example train_mlp.py.)
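A minimal sketch of the same immutability in plain Python and JAX, assuming a frozen dataclass registered by hand as a PyTree. This is an illustration, not necessarily how equinox.Module enforces it; the class Model is hypothetical.

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Model:
    weight: jax.Array

# Register Model as a PyTree: `weight` is its only child, with no static data.
jax.tree_util.register_pytree_node(
    Model,
    lambda m: ((m.weight,), None),
    lambda _aux, children: Model(*children),
)

model = Model(weight=jnp.ones(3))

try:
    model.weight = jnp.zeros(3)  # in-place mutation is refused
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# An "update" creates a new model; the original is left untouched.
new_model = jax.tree_util.tree_map(lambda w: w - 0.1, model)
```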

It is typical to also create some methods on the class. As self will be an input parameter -- treated as a PyTree -- then these methods will get access to the attributes of the instance. Defining __call__ gives an easy way to define a forward pass for a model:

from typing import Any
import equinox

class LinearWithoutBias(equinox.Module):
    weight: Any

    def __call__(self, x):
        return self.weight @ x

If defining a method meth, then take care not to write instance = MyModule(...); jax.jit(instance.meth)(...). (Or similarly with jax.grad, equinox.jitf etc.) This is because instance.meth is not a pure function as it already has the self parameter passed implicitly. Instead do either jax.jit(MyModule.meth)(instance, ...) or

import jax

@jax.jit
def func(instance, args):
    # Also use this pattern with instance(args) if you defined `__call__` instead of `meth`.
    return instance.meth(args)

Utilities

equinox.apply_updates(model, updates)

Performs a training update to a model.

  • model must be a PyTree;
  • updates must be a PyTree with the same structure.

It essentially performs jax.tree_map(lambda m, u: m + u, model, updates). However anywhere updates is None then no update is made at all, so as to handle nondifferentiable parts of model.

The returned value is the updated model. (model is not mutated in place, as is usual in JAX and functional programming.)

To produce updates, it is typical to take the gradients from the loss function, and then adjust them according to any standard optimiser; for example Optax provides optax.sgd or optax.adam.
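A minimal sketch of this None-skipping behaviour in plain JAX. Here apply_updates_sketch is a hypothetical helper, not Equinox's code; updates is passed as the first tree so that is_leaf can mark its None entries, and the updates come from a hand-written SGD step rather than Optax.

```python
import jax
import jax.numpy as jnp

def apply_updates_sketch(model, updates):
    # Treat None as a leaf of `updates` so it lines up with the corresponding
    # (nondifferentiable) part of `model`, which is then kept as-is.
    def _apply(u, m):
        return m if u is None else m + u

    return jax.tree_util.tree_map(_apply, updates, model, is_leaf=lambda x: x is None)

model = {"weight": jnp.array([1.0, 2.0]), "flag": True}
grads = {"weight": jnp.array([0.5, 0.5]), "flag": None}  # as a filtered grad returns
learning_rate = 0.1

# Plain SGD: update = -learning_rate * gradient (None entries are skipped).
updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
new_model = apply_updates_sketch(model, updates)
```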

equinox.tree_at(where, pytree, replace=_sentinel, replace_fn=_sentinel)

Modifies an existing tree, and returns the modified tree. (Like .at for "in place modifications" of JAX arrays.)

  • where is a callable PyTree -> Leaf or PyTree -> Tuple[Leaf, ...]. It should consume a PyTree of the same shape as pytree, and return the leaf or leaves that should be replaced. For example where=lambda mlp: mlp.layers[-1].linear.weight.
  • pytree is the existing PyTree to modify.
  • replace should either be a single element, or a tuple of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
  • replace_fn should be a function Leaf -> Any. It will be called on every leaf replaced using where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.

For example this can be used to specify the weights of a model to train or not train:

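# Start with every leaf marked False (not trained)...
trainable = jax.tree_map(lambda _: False, model)
# ...then mark just the final layer's weight as trainable.
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
equinox.gradf(..., filter_tree=trainable)

equinox.tree_equal(*pytrees)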

Returns True if all PyTrees in the list are equal. All arrays must have the same shape, dtype, and values. JAX arrays and NumPy arrays are not considered equal.
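A minimal sketch of these equality semantics, for illustration only: tree_equal_sketch is a hypothetical name, it requires at least one PyTree, and it ignores edge cases such as NaN values.

```python
import jax
import jax.numpy as jnp
import numpy as np

def tree_equal_sketch(*pytrees):
    flat, treedefs = zip(*(jax.tree_util.tree_flatten(t) for t in pytrees))
    # Structures must match exactly.
    if any(td != treedefs[0] for td in treedefs[1:]):
        return False
    for leaves in zip(*flat):
        first = leaves[0]
        for other in leaves[1:]:
            if type(other) is not type(first):
                # e.g. a JAX array vs a NumPy array: not considered equal.
                return False
            if isinstance(first, (jax.Array, np.ndarray)):
                if first.shape != other.shape or first.dtype != other.dtype:
                    return False
                if not bool((first == other).all()):
                    return False
            elif first != other:
                return False
    return True
```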

Neural network library

Equinox includes a small neural network library, mostly as a tech demo for how the rest of the library can be used. Its API is modelled after PyTorch.

equinox.nn.Linear(in_features, out_features, bias=True, *, key)(input)
equinox.nn.Identity(*args, **kwargs)(input)  # args and kwargs are ignored
equinox.nn.Dropout(p=0.5, deterministic=False)(input, *, key=None, deterministic=None)
equinox.nn.GRUCell(input_size, hidden_size, bias=True, *, key)(input, hidden)
equinox.nn.LSTMCell(input_size, hidden_size, bias=True, *, key)(input, hidden)
equinox.nn.Sequential(layers)(input, *, key=None)
equinox.nn.MLP(in_size, out_size, width_size, depth,
               activation=jax.nn.relu, final_activation=lambda x: x, *, key)(input)

These all behave in the way you expect. The key arguments are used to generate the random initial weights, or to generate randomness on the forward pass of stochastic layers like Dropout.

The Dropout(deterministic=...)(deterministic=...) options determine whether to have the layer act as the identity function, as is commonly done with dropout at inference time. The call-time deterministic takes precedence if it is passed; otherwise the init-time deterministic is used. (Note that because models are PyTrees, you can modify the init-time deterministic flag using equinox.tree_at. This is perfectly fine, and might be handy if it's easier than using the call-time flag.)
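A functional sketch of this behaviour in plain JAX. It assumes standard inverted dropout (kept units rescaled by 1 / (1 - p)); dropout_sketch and resolve_deterministic are hypothetical helpers, and this is not claimed to match equinox.nn.Dropout line for line.

```python
import jax
import jax.numpy as jnp

def dropout_sketch(x, p, key=None, deterministic=False):
    # Identity at inference time (or when nothing is dropped).
    if deterministic or p == 0.0:
        return x
    if key is None:
        raise ValueError("A PRNG key is required when deterministic=False.")
    keep = jax.random.bernoulli(key, 1.0 - p, x.shape)
    # Inverted dropout: rescale kept units so the expected value matches x.
    return jnp.where(keep, x / (1.0 - p), 0.0)

def resolve_deterministic(init_flag, call_flag=None):
    # The call-time flag takes precedence if passed; otherwise use the init-time flag.
    return init_flag if call_flag is None else call_flag
```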

The MLP(final_activation=...) option determines any final activation function to apply after the last layer. (In some cases it is desirable for this to be different to the activation used in the main part of the network.)
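A plain-JAX sketch of an MLP whose last layer has its own activation. It assumes depth counts hidden layers (giving depth + 1 linear layers) and uses a simple scaled-normal initialisation; init_mlp and mlp_forward are hypothetical helpers, not necessarily how equinox.nn.MLP is implemented.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_size, out_size, width_size, depth):
    # Assumption: `depth` counts hidden layers, giving depth + 1 linear layers.
    sizes = [in_size] + [width_size] * depth + [out_size]
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp_forward(layers, x, activation=jax.nn.relu, final_activation=lambda x: x):
    for weight, bias in layers[:-1]:
        x = activation(weight @ x + bias)
    weight, bias = layers[-1]
    # The last layer gets its own activation (identity by default).
    return final_activation(weight @ x + bias)

layers = init_mlp(jax.random.PRNGKey(0), in_size=2, out_size=1, width_size=8, depth=2)
# e.g. a sigmoid on the output only, to produce a probability.
prob = mlp_forward(layers, jnp.ones(2), final_activation=jax.nn.sigmoid)
```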

Comments
  • Add some kind of de/serialisation?

    Equinox models are just PyTrees so they should be very easy to serialise/deserialise; just save the PyTree to disk in whatever way is desired. It might be worth adding some library functions for this just for convenience. Perhaps checking the device of JAX arrays etc?

    This should respect the get_state/set_state stuff that's being put together.

    In addition, there should be a version of get_state which inlines its state in the call graph, for faster inference.

    feature 
    opened by patrick-kidger 13
  • Support Buffer Donation in filter_jit

    Try to close #3

    Hi @patrick-kidger

    JAX supports buffer donation on CPU in jaxlib > 0.3.22; it is necessary for many in-place updates.

    I've drafted a proposal here.

    Details:

    • Adds donate_default, donate_args, donate_kwargs, and donate_fn for control of buffer donation;
    • Adds Tests.

    Considerations:

    • It would be trivial to make these arguments work only with parameters marked as traced.
    • We need the ability to control buffer donation in function, args and kwargs, which is similar to handling Tracing.

    More (not in this PR): support buffer donation in filter_pmap.

    opened by uuirs 10
  • How to get outputs from an intermediate layer?

    Hi @patrick-kidger,

    I was looking into obtaining outputs from some specific layer of the network without having to modify the network __call__ method or the network itself.

    One way I think of achieving this is using id_tap from jax.experimental.hcb.... The solution would be to wrap the target layer of the network in a class with integrated id_tap. Something like net = eqx.tree_at(lambda net: net.layer_0, net, Wrapper(net.layer_0)).

    1. Would this be the best way to capture intermediate outputs?
    2. A network without BN seems to be working fine (prints normally with id_print). However, with BatchNorm, throws an error
    >     flat_results = outside_call_p.bind(*flat_args, **params)
    E     TypeError: _outside_call_batching_rule() missing 1 required keyword-only argument: 'result_treedef'
    
    ...python3.8/site-packages/jax/experimental/host_callback.py:763: TypeError
    
    question 
    opened by paganpasta 10
  • Default serialisation fails for `BatchNorm`.

    Hi,

    The default serialisation fails when a model with BatchNorm is serialised. A small test script, executed on the dev branch:

    
    def test_serialise_bn(getkey):
        net = eqx.nn.Sequential(
            [
                eqx.experimental.BatchNorm(3, axis_name="batch"),
            ]
        )
    
        eqx.tree_serialise_leaves('/tmp/net.eqx', net)
    
        assert True
    

    with the error

    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
    ../equinox/serialisation.py:183: in tree_serialise_leaves
        jtu.tree_map(_serialise, filter_spec, pytree)
    ../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    ../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    ../equinox/serialisation.py:181: in _serialise
        jtu.tree_map(__serialise, x, is_leaf=is_leaf)
    ../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    ../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    ../equinox/serialisation.py:179: in __serialise
        spec(f, y)
    ../equinox/serialisation.py:50: in default_serialise_filter_spec
        value, _, _ = x.unsafe_get()
    ../equinox/experimental/stateful.py:112: in unsafe_get
        return _state_cache[self._obj]
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
    
    self = <WeakKeyDictionary at 0x7fc6deff8d90>
    key = <equinox.experimental.stateful._IndexObj object at 0x7fc6deec12b0>
    
        def __getitem__(self, key):
    >       return self.data[ref(key)]
    E       KeyError: <weakref at 0x7fc736f7e590; to '_IndexObj' at 0x7fc6deec12b0>
    
    ../../../miniconda3/envs/equinox/lib/python3.8/weakref.py:383: KeyError
    
    bug 
    opened by paganpasta 10
  • Unexpected (?) BatchNorm behavior: model's flattened form changes through iterations

    Hi, This might be something that is already known, or perhaps I'm not using the library as intended. Apologies in advance if that's the case. First some background info:

    I'm writing code for a scenario that features a form of pipeline parallelism. I have a model, which I split in parts/modules, and each part is run on a different device. The results of each part are passed on to the next in a loop. The model features BatchNorm (I'm trying to implement some known results that use it, although I'm now aware that BatchNorm is finicky in Jax).

    As a test case, I feed N batches of the exact same samples in the first N iterations, then do some updates on my model. I repeat this procedure with a new batch, which is fed repeatedly for the next N iterations. As a sanity check, in every N consecutive iterations, the model should output the same values. This is not the case, though, and I think BatchNorm might be the issue.

    To debug, I thought I'd check whether the model's parameters change during these N iterations, by flattening it and comparing it to its previous version. However, I run into errors regarding "List arity mismatch". I have a very simplified example that exhibits this sort of behavior below. To simulate my use case, the second module/part is only run from the third iteration onward. Even for i = 1, the two model "versions" are not comparable (one was before running anything, the second after running the first module/part).

    If I remove the BatchNorm layers there are no errors, which leads me to believe that the fact that it modifies its state is the problem. Am I using something wrong here? If not, how can I work around this, and what could possibly cause my model's output to be different for the same inputs?

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import jax.random as jr
    
    key = jr.PRNGKey(0)
    mkey, dkey = jr.split(key)
    model_pt1 = eqx.nn.Sequential([
        eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
        eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
    ])
    model_pt2 = eqx.nn.Sequential([   
        eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
        eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
    ])
    
    model_combined = [model_pt1, model_pt2]
    
    x = jr.normal(dkey, (10, 3))
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
    for i in range(10):
        prev_flattened_model = flattened_model
        flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
        
        diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), flattened_model, prev_flattened_model)
        y1 = jax.vmap(model_pt1, axis_name="batch")(x)
        if i >= 2:
            y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
    
    question 
    opened by geomlyd 9
  • Added `eqx.experimental.noinline`

    TL;DR: XLA sub-graphs!

    Background

    At present, JAX inlines the entire computation into a single XLA graph.

    However, many scientific computing applications involve defining some modestly complicated function and then calling this function numerous times in different contexts. For example the vector field for an ODE must be traced 22 times when using the Kvaerno5 solver with automatic initial step size selection. (In ways that cannot easily be tidied up into a lax.scan or similar.)

    Inlining without awareness of the repeated structure means the compiler is less efficient than it could be. I know of current examples with compile times about an hour long.

    To support this use case there's been talk for quite a while about adding support to JAX or XLA for sub-programs, e.g. https://github.com/google/jax/issues/10974 https://github.com/google/jax/issues/4572 https://github.com/google/jax/issues/3847 https://github.com/google/jax/issues/9298

    no-inline decorator

    Introducing equinox.experimental.noinline. This decorator places the function in a separate XLA computation graph, and links it up with the main one via jax.experimental.host_callback.call. The decorated function is only traced once; only one copy of it exists as a jaxpr; it is only compiled once. It can still be transformed via grad, vmap etc.

    Running the included benchmark benchmarks/noinline.py we obtain a reduction in compile time 36 seconds -> 25 seconds, at the expense of a large runtime increase, 0.002 seconds -> 0.6 seconds. In practice that's still a good net saving (36 seconds -> 25.6 seconds) in the common use-case that you're developing + debugging your program.

    Going further and switching the solver in the benchmark from dfx.Kvaerno5() to dfx.Dopri8() gives even better results: a compile time reduction of 23 seconds -> 8 seconds (!), with a runtime increase of 0.002 seconds -> 0.16 seconds. (I chose not to implement this as the default benchmark, just because I have other unrelated plans for improving the compile time of Dopri8.)

    Limitations

    1. All the awful hackery and monkey-patching of JAX internals needed to make this work.
    2. This will only be efficient on the CPU. On the GPU it'll entail copies to and from the device. However I speculate that this isn't actually necessary, and may just be a limitation of our current use of host_callback?
    3. The runtime performance has cratered. I speculate a lot of that cost is due to the back-and-forth switching via Python, again due to our use of host_callback. (Flame graphs TBD.) Possibly also something GIL related?

    Nonetheless, our main use-case is on the CPU and the overall compile-time improvements on the benchmark represent compile speed improvements of 1.5x to 3x, which is enough to make me happy. This is something we're looking forward to relying on as those 1+ hour compile times are really biting us.

    CC

    @shoyer and @YouJiacheng as I know you've both wanted this functionality in the past. @federicov (+team) for using this.

    Also speculatively tagging @gnecula (I have it in my head that you're behind host_callback?) @mattjj for possible interest. (Feel free to ignore this.)

    opened by patrick-kidger 9
  • Implementing Batch Normalization

    In Flax, Batch Normalization is a bit finicky since each call to apply requires marking batch_stats as mutable and updating the batch_stats afterward.

    import flax.linen
    import jax.numpy as jnp
    from jax import random
    
    bn = flax.linen.BatchNorm(use_running_average=True)
    
    x = jnp.arange(24).reshape(3, 6)
    
    vars = bn.init(random.PRNGKey(0), x)
    
    # Mark the batch stats as mutable so we can update them in the variable dictionary
    x_normed, mutated_vars = bn.apply(vars, x, mutable=['batch_stats'])
    
    vars = {**vars, **mutated_vars}  # Update the variables with our diff
    
    x_normed2, mutated_vars2 = bn.apply(vars, x, mutable=['batch_stats'])
    

    How could this be implemented as a Module in Equinox? I'm happy to submit an implementation given some guidance.

    opened by marcelroed 9
  • Add attention functions and tests

    Adds dot_product_attention_weights and dot_product_attention functions and tests

    Design considerations:

    • dot_product_attention and dot_product_attention_weights don't take multi-head inputs -- instead attention heads are vmap'd over in MultiheadAttention. This allows for greater flexibility when creating other types of attention modules
    • To simplify the dot_product_attention signature -- dropout_fn is added as a single argument callable, which should close over the dropout arguments like key and inference. The alternative I think would be to add a functional version of dropout and add its arguments to dot_product_attention, however this would make changing the dropout rate after initializing the module less intuitive -- since dropout rate would have to be an attribute of MultiheadAttention.
    • mask shape check is kept inside dot_product_attention_weights. The downside to this is that errors raised inside vmap'd functions are less obvious -- i.e. if the heads don't match then vmap function will raise an error. The alternative is to pull the shape check out and put it back in MultiheadAttention
    opened by jenkspt 8
  • eqx.filter_{vmap,pmap}(out=...) not experimental

    1. Previously using a callable out parameter was experimental for filter_vmap -- because it monkey-patched JAX internals -- and unavailable for filter_pmap. It has now been updated to work for both, and using only the public JAX API. (Hurrah.)

    2. Drive-by: Added eqx.filter_eval_shape, as it looked like this was going to be useful as part of implementing the previous feature. In the end this wasn't the case, but we get a new feature anyway.

    3. Drive-by: Fixed a crash bug when filter-jit'ing a partial-wrapped function whilst using a PyTree as its filter spec (fn).

    Closes #115.

    CC @jatentaki WDYT?

    opened by patrick-kidger 8
  • Support for adaptive average pooling?

    Hi again,

    I was trying out Equinox for some computer vision experiments and found that at the moment there is no support for adaptive average pooling. Similar functionality exists in PyTorch. So I just wanted to check whether it is something you intend to add later.

    Thanks.

    feature 
    opened by paganpasta 8
  • Added max and avg pool, alongside tests and docs

    I currently use einops.rearrange so that the pooling layers and conv layers take the same tensor format as input, BCHW. For some reason lax.conv_general_dilated and lax.reduce_window are BCHW and BHWC respectively. I can do this without einops if you would prefer not to have the extra dependency.

    opened by Benjamin-Walker 8
  • PyTorch DataLoading issue (with Equinox ?)

    Hello,

    I encountered an issue (and a fix) about loading data with a PyTorch DataLoader, when used with JAX (and I think Equinox). I am not sure this belongs exactly here so please feel free to tell me and I will move it somewhere else. I also mention #137, since I feel this is related.

    The setup:

    I train a small MLP on a classification task on MNIST. I use the data loader given in the JAX documentation (https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html), notably the NumpyLoader and the associated collate function. When I run the training script I get incoherent losses and accuracies. If I use the standard PyTorch DataLoader, I do not face the issue.

    I link two files (one failing and one passing). If anyone has an idea why it does not work, I would love to know. I hope this can also help others, since I have been looking into it for 2 days now.

    Thank you for any input

    link to file : https://gist.github.com/pablo2909/3a2cec869a43421859520750990f263e

    question 
    opened by pablo2909 1
  • PReLU Activation

    Hi @patrick-kidger, I am a big fan of Equinox; thank you for publishing this framework. I am currently working on a project and decided to give PReLU a try, so I was wondering whether it would be possible to add a PReLU activation function to a module such as nn.activations.

    My current implementation is as follows:

    import equinox as eqx
    import jax.numpy as jnp
    import jaxtyping

    class PReLU(eqx.Module):
        negative_slope: jaxtyping.Array  # learnable slope applied to negative inputs
    
        def __init__(self, alpha):
            # Stored as a 1-element array so it is a trainable PyTree leaf.
            self.negative_slope = jnp.array((alpha,))
    
        def __call__(self, x):
            # Identity for x >= 0; scaled by negative_slope for x < 0.
            return jnp.where(x >= 0, x, self.negative_slope * x)
    
    feature 
    opened by enver1323 2
  • using equinox with xmap

    Hello, thank you for publishing a fantastic library! Is it possible to use named-axis operations (xmap) within the __init__ and __call__ of a module, and then xmap over all of the created modules? I suppose it is, since a module is just a PyTree, but I do not know how.

    question 
    opened by jakubMitura14 2
  • `BatchNorm` raises `RuntimeError: Cannot get state before it has been set` when used in a scan

    BatchNorm raises RuntimeError: Cannot get state before it has been set when used in a scan. A minimal reproducing example is given below. Note that calling vmapped_fun twice and decorating with eqx.filter_grad are both necessary to trigger the error, and that replacing jax.lax.map(bn, xs) with jax.vmap(bn)(xs) avoids it.

    import jax
    import jax.numpy as jnp
    import jax.random as jrandom
    import equinox as eqx
    
    def fun(bn, xs):
        return jax.lax.map(bn, xs)
    
    @eqx.filter_grad
    def vmapped_fun(bn, xss):
        return jnp.mean(jax.vmap(fun, (None, 0), axis_name="batch")(bn, xss))
    
    if __name__=="__main__":
        bn = eqx.experimental.BatchNorm(10, axis_name="batch")
        inp = jrandom.normal(jrandom.PRNGKey(0), (32, 20, 10))
    
        grad = vmapped_fun(bn, inp)
        grad = vmapped_fun(bn, inp)
    

    outputs

    ERROR:jax.experimental.host_callback:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x11a530550> threw exception Cannot get state before it has been set.
    Traceback (most recent call last):
      File "/Users/andriusovsianas/repos/test/test.py", line 18, in <module>
        grad = vmapped_fun(bn, inp)
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 53, in __call__
        value, grad = __self._fun_value_and_grad(*args, **kwargs)
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 40, in __call__
        return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py", line 1167, in value_and_grad_f
        ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py", line 2656, in _vjp
        out_primal, out_vjp = ad.vjp(
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py", line 135, in vjp
        out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py", line 124, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 767, in trace_to_jaxpr_nounits
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 37, in fun_value_and_grad
        return __self._fun(_x, *_args, **_kwargs)
      File "/Users/andriusovsianas/repos/test/test.py", line 11, in vmapped_fun
        return jnp.mean(jax.vmap(fun, (None, 0), axis_name="batch")(bn, xss))
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py", line 1682, in vmap_f
        out_flat = batching.batch(
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/Users/andriusovsianas/repos/test/test.py", line 7, in fun
        return jax.lax.map(bn, xs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1730, in map
        _, ys = scan(g, (), xs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 275, in scan
        out = scan_p.bind(*consts, *in_flat,
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1000, in scan_bind
        return core.AxisPrimitive.bind(scan_p, *args, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 2444, in bind
        return self.bind_with_trace(top_trace, args, params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 332, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/batching.py", line 350, in process_primitive
        val_out, dim_out = batcher_primitive(vals_in, dims_in, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 755, in _scan_batching_rule
        outs = scan_p.bind(
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1000, in scan_bind
        return core.AxisPrimitive.bind(scan_p, *args, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 2444, in bind
        return self.bind_with_trace(top_trace, args, params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 332, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py", line 310, in process_primitive
        primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 485, in _scan_jvp
        out_flat = scan_p.bind(
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1000, in scan_bind
        return core.AxisPrimitive.bind(scan_p, *args, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 2444, in bind
        return self.bind_with_trace(top_trace, args, params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 332, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 207, in process_primitive
        return custom_partial_eval_rules[primitive](self, *tracers, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 595, in _scan_partial_eval
        out_known = scan_p.bind(
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1000, in scan_bind
        return core.AxisPrimitive.bind(scan_p, *args, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 2444, in bind
        return self.bind_with_trace(top_trace, args, params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 332, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py", line 712, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
        return compiled_fun(*args)
      File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
        out_flat = compiled.execute(in_flat)
    jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Cannot get state before it has been set
    
    At:
      /Users/andriusovsianas/repos/equinox/equinox/experimental/stateful.py(355): _get_state_hcb
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/experimental/host_callback.py(726): __call__
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/experimental/host_callback.py(1295): _outside_call_run_callback
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/experimental/host_callback.py(1164): wrapped_callback
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/mlir.py(1579): _wrapped_callback
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/mlir.py(1604): _wrapped_callback
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/dispatch.py(895): _execute_compiled
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/dispatch.py(115): apply_primitive
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(712): process_primitive
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(332): bind_with_trace
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(2444): bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(1000): scan_bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(595): _scan_partial_eval
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/partial_eval.py(207): process_primitive
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(332): bind_with_trace
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(2444): bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(1000): scan_bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(485): _scan_jvp
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py(310): process_primitive
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(332): bind_with_trace
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(2444): bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(1000): scan_bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(755): _scan_batching_rule
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/batching.py(350): process_primitive
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(332): bind_with_trace
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/core.py(2444): bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(1000): scan_bind
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(275): scan
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py(1730): map
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
      /Users/andriusovsianas/repos/test/test.py(7): fun
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/linear_util.py(167): call_wrapped
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py(1682): vmap_f
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
      /Users/andriusovsianas/repos/test/test.py(11): vmapped_fun
      /Users/andriusovsianas/repos/equinox/equinox/grad.py(37): fun_value_and_grad
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/linear_util.py(167): call_wrapped
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/partial_eval.py(767): trace_to_jaxpr_nounits
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py(124): linearize
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/interpreters/ad.py(135): vjp
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py(2656): _vjp
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/api.py(1167): value_and_grad_f
      /Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
      /Users/andriusovsianas/repos/equinox/equinox/grad.py(40): __call__
      /Users/andriusovsianas/repos/equinox/equinox/grad.py(53): __call__
      /Users/andriusovsianas/repos/test/test.py(18): <module>
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/Users/andriusovsianas/repos/test/test.py", line 18, in <module>
        grad = vmapped_fun(bn, inp)
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 53, in __call__
        value, grad = __self._fun_value_and_grad(*args, **kwargs)
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 40, in __call__
        return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
      File "/Users/andriusovsianas/repos/equinox/equinox/grad.py", line 37, in fun_value_and_grad
        return __self._fun(_x, *_args, **_kwargs)
      File "/Users/andriusovsianas/repos/test/test.py", line 11, in vmapped_fun
        return jnp.mean(jax.vmap(fun, (None, 0), axis_name="batch")(bn, xss))
      File "/Users/andriusovsianas/repos/test/test.py", line 7, in fun
        return jax.lax.map(bn, xs)
    jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Cannot get state before it has been set
    
    bug 
    opened by ciupakabra 1
  • BatchNorm raises "TypeError: Expected a callable value, got inf"

    I'm getting weird errors with BatchNorm. One example is the code below, where ODE parameters are optimized and the drift is a neural network with some BatchNorm layers. The error thrown is TypeError: Expected a callable value, got inf when vmapping a loss function.

    I wasn't able to reduce this to an example without diffrax: double vmaps, or vmapping a scan, don't raise exceptions. I'm guessing this is an Equinox issue, since without BatchNorm (bn=False) there are no problems.

    import optax
    import jax
    import jax.numpy as jnp
    import jax.random as jrandom
    import jax.nn as jnn
    import equinox as eqx
    import diffrax as dx
    from tqdm import tqdm
    
    def integrate(drift, num_steps, dim, key):
    
        def f(t, y, args):
            return drift(jnp.concatenate([t[None], y]))
    
        drift_term = dx.ODETerm(f)
        solver = dx.Euler()
        y0 = jnp.zeros(dim)
    
        ts = jnp.linspace(0, 1, num_steps + 1)
        saveat = dx.SaveAt(ts=ts)
    
        sol = dx.diffeqsolve(
            drift_term,
            solver,
            0,
            1,
            1/num_steps,
            y0,
            saveat=saveat,
            max_steps=num_steps + 1,
        )
    
        return sol.ys
    
    def loss(drift, num_steps, dim, key):
        path = integrate(drift, num_steps, dim, key)
        final = path[-1]
        loss = jnp.sum(final**2)
        return loss
    
    @eqx.filter_value_and_grad
    def loss_mean(drift, num_steps, dim, key, batch_size):
        loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
        key = jrandom.split(key, batch_size)
        return jnp.mean(loss_vmapped(drift, num_steps, dim, key))
    
    
    class Network(eqx.Module):
        net: eqx.Module
    
        def __init__(self, in_size, out_size, width, depth, *, key, bn=True):
    
            keys = jrandom.split(key, depth + 1)
            layers = []
            if depth == 0:
                layers.append(eqx.nn.Linear(in_size, out_size, key=keys[0]))
            else:
                layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
                if bn:
                    layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                for i in range(depth - 1):
                    layers.append(eqx.nn.Linear(width, width, key=keys[i + 1]))
                    if bn: 
                        layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                    layers.append(eqx.nn.Lambda(jnn.relu))
                layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))
    
            self.net = eqx.nn.Sequential(layers)
    
        def __call__(self, x):
            return self.net(x)
    
    
    if __name__=="__main__":
    
        key = jrandom.PRNGKey(0)
    
        init_drift_key, train_key = jrandom.split(key, 2)
    
        dim = 500
    
        drift = Network(dim + 1, dim, 300, 2, key=init_drift_key, bn=True)
    
        optimizer = optax.adamw(1e-4)
        opt_state = optimizer.init(eqx.filter(drift, eqx.is_inexact_array))
        
        @eqx.filter_jit
        def make_step(drift, num_steps, dim, key, batch_size, opt_state):
            loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
            updates, opt_state = optimizer.update(
                grads, opt_state, eqx.filter(drift, eqx.is_inexact_array)
            )
            drift = eqx.apply_updates(drift, updates)
            return loss, drift, opt_state
    
    
        for step in tqdm(range(100)):
            step_key = jrandom.fold_in(train_key, step)
            loss, drift, opt_state = make_step(
                drift, 10, dim, step_key, 32, opt_state
            )
    
    (env) andrius:/home/andrius/repos/test% python test.py  
      1%|▏          | 1/100 [00:01<03:14,  1.97s/it]
    Traceback (most recent call last):
      File "/home/andrius/repos/test/test.py", line 99, in <module>
        loss, drift, opt_state = make_step(
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
        return __self._fun_wrapper(False, args, kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
        dynamic_out, static_out = self._cached(dynamic, static)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
        return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
        ans = call(fun, *args)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars, False,
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
        jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
        jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
        out = fun(*args, **kwargs)
      File "/home/andrius/repos/test/test.py", line 89, in make_step
        loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
        return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1167, in value_and_grad_f
        ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 2656, in _vjp
        out_primal, out_vjp = ad.vjp(
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 135, in vjp
        out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 124, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 767, in trace_to_jaxpr_nounits
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
        return __self._fun(_x, *_args, **_kwargs)
      File "/home/andrius/repos/test/test.py", line 43, in loss_mean
        loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1647, in vmap
        _check_callable(fun)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 181, in _check_callable
        raise TypeError(f"Expected a callable value, got {fun}")
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: Expected a callable value, got inf
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/andrius/repos/test/test.py", line 99, in <module>
        loss, drift, opt_state = make_step(
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
        return __self._fun_wrapper(False, args, kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
        dynamic_out, static_out = self._cached(dynamic, static)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
        out = fun(*args, **kwargs)
      File "/home/andrius/repos/test/test.py", line 89, in make_step
        loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
        return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
      File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
        return __self._fun(_x, *_args, **_kwargs)
      File "/home/andrius/repos/test/test.py", line 43, in loss_mean
        loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
    TypeError: Expected a callable value, got inf
    
    (env) andrius:/home/andrius/repos/test% pip list
    Package           Version
    ----------------- ---------------------
    absl-py           1.3.0
    chex              0.1.5
    diffrax           0.2.2
    dm-tree           0.1.7
    equinox           0.9.2
    jax               0.3.25
    jaxlib            0.3.25+cuda11.cudnn82
    jaxtyping         0.2.8
    numpy             1.23.5
    opt-einsum        3.3.0
    optax             0.1.4
    pi                0.1.2
    pip               22.3.1
    scipy             1.9.3
    setuptools        65.5.0
    toolz             0.12.0
    tqdm              4.64.1
    typeguard         2.13.3
    typing_extensions 4.4.0
    wheel             0.37.1
    
    bug question 
    opened by ciupakabra 6
Releases(v0.9.2)
  • v0.9.2(Nov 17, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Minor doc fixes by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/228
    • Allow passing file-like objects to eqx.serialise/deserialise by @jatentaki in https://github.com/patrick-kidger/equinox/pull/229
    • Fixed broken filter_closure_convert (and new JAX breaking Equinox's experimental stateful operations) by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/232

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.9.1...v0.9.2

  • v0.9.1(Nov 15, 2022)

    New features

    These are all pretty self-explanatory!

    • equinox.filter_make_jaxpr
    • equinox.filter_vjp
    • equinox.filter_closure_convert
    • equinox.filter_pure_callback

    Also:

    • equinox.internal.debug_backward_nan(x) will print out the primal and cotangent for x, and if the cotangent has a NaN then the computation is halted.

    Bugfixes

    • equinox.{is_array, is_array_like, is_inexact_array, is_inexact_array_like} all now recognise NumPy scalars as being array types.
    • equinox.internal.{error_if, branched_error_if} are now compatible with jax.ensure_compile_time_eval.
    • equinox.internal.noinline will no longer throw an assert error during tracing under certain edge-case conditions. (In particular, when it is part of a branch of a vmap'd lax.cond with a batched predicate.)
    • equinox.tree_pformat now prints out jax.tree_util.Partials, and dataclass types (not instances) correctly.

    Tweaks:

    • equinox.internal.noinline is now compatible with jax.jit, i.e. a noinline-wrapped function can be passed across a jit API boundary. (Previously equinox.filter_jit was required.)
    • equinox.internal.announce_jaxpr has been renamed to equinox.internal.announce_transform.
    • equinox.internal.{nondifferentiable, nondifferentiable_backward} now take a msg argument for overriding their error messages.

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.9.0...v0.9.1

  • v0.9.0(Nov 2, 2022)

    This is a big update. The highlight here is the new equinox.internal namespace, which contains a slew of advanced features.

    These are only "semi-public": they are deliberately not in the main documentation, and exist primarily for the benefit of downstream libraries like Diffrax. But you may still have fun playing with them.

    Features

    • equinox.internal.
      • Autodiff:
        • nondifferentiable: will raise an error at trace-time if you attempt to differentiate it.
        • nondifferentiable_backward: will raise an error at trace-time if you attempt to reverse-mode differentiate it.
      • Debug tooling:
        • announce_jaxpr: will call a custom callback whenever it is traced/transformed in a jaxpr. print(<transform stack>) is the default callback.
      • Runtime errors:
        • error_if: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
        • branched_error_if: can raise one of multiple errors, depending on a traced value.
      • Floating point manipulation:
        • nextafter: returns the next floating point number. Unlike jnp.nextafter, it is differentiable.
        • prevbefore: returns the previous floating point number. Also differentiable.
      • MLIR sub-graphs:
        • noinline: used to mark that a subcomputation should be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times if it is called repeatedly. Can also be used to iteratively recompile just parts of a computation graph, if the sub/super-graph is the only thing that changes.
      • Omega:
        • ω: nice syntax for tree arithmetic. For example (x**ω + y**ω).ω == tree_map(operator.add, x, y). Like tree-math but with nicer syntax.
      • Custom primitives:
        • filter_primitive_{def, jvp, transpose, batching, bind}: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
        • create_vprim: Autodefines batching rules for higher-order primitives, according to transform(vmap(prim)) == vmap(transform(prim)).
      • String handling:
        • str2jax: turns a string into a JAX'able object.
      • Unvmap'ing:
        • unvmap_{any, all, max}: apply reductions whilst ignoring the batch dimension.
    • New filtered transformations: eqx.{filter_jvp,filter_custom_jvp}
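The autodiff helpers (nondifferentiable) and floating point helpers (nextafter, prevbefore) described above can be roughly sketched with jax.custom_jvp. This is a hypothetical stand-in rather than Equinox's code: the names mirror the release note, but treating the derivative of nextafter/prevbefore as 1, and raising a RuntimeError from the JVP rule, are assumptions about the intended behaviour.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def nextafter(x):
    # Next representable floating point number above x.
    return jnp.nextafter(x, jnp.inf)


@nextafter.defjvp
def _nextafter_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Assumption: treat the one-ulp nudge as the identity when differentiating.
    return jnp.nextafter(x, jnp.inf), t


@jax.custom_jvp
def prevbefore(x):
    # Previous representable floating point number below x.
    return jnp.nextafter(x, -jnp.inf)


@prevbefore.defjvp
def _prevbefore_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return jnp.nextafter(x, -jnp.inf), t


@jax.custom_jvp
def nondifferentiable(x):
    return x  # identity on the forward pass


@nondifferentiable.defjvp
def _nondifferentiable_jvp(primals, tangents):
    # This rule runs while JAX traces a JVP (which jax.grad also uses),
    # so the error surfaces at trace time rather than at run time.
    raise RuntimeError("nondifferentiable(x) was differentiated")
```

Calling jax.grad(nextafter) then gives a gradient of 1, whereas jax.grad(nondifferentiable) fails as soon as it is traced.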

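To illustrate what the ω syntax means, here is a toy reimplementation built only on jax.tree_util.tree_map. The classes _Omega and _OmegaMarker are hypothetical, cover only +, - and *, and are not Equinox's own ω. Wrapping works on containers such as dicts and tuples: they define no __pow__, so Python falls back to ω's __rpow__.

```python
import operator

import jax
import jax.numpy as jnp


class _Omega:
    """Wraps a PyTree so that arithmetic is applied leaf-by-leaf."""

    def __init__(self, tree):
        self.ω = tree  # `.ω` unwraps back to a plain PyTree

    def _binop(self, other, op):
        if isinstance(other, _Omega):
            # Both operands are wrapped: combine matching leaves.
            return _Omega(jax.tree_util.tree_map(op, self.ω, other.ω))
        # Otherwise treat `other` as a scalar applied to every leaf.
        return _Omega(jax.tree_util.tree_map(lambda leaf: op(leaf, other), self.ω))

    def __add__(self, other):
        return self._binop(other, operator.add)

    def __sub__(self, other):
        return self._binop(other, operator.sub)

    def __mul__(self, other):
        return self._binop(other, operator.mul)


class _OmegaMarker:
    def __rpow__(self, tree):
        # `tree**ω` ends up here because containers such as dict and tuple
        # define no __pow__ of their own.
        return _Omega(tree)


ω = _OmegaMarker()

x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"a": jnp.array([10.0, 20.0]), "b": jnp.array(4.0)}

total = (x**ω + y**ω).ω  # same leaves as tree_map(operator.add, x, y)
scaled = (x**ω * 2.0).ω  # every leaf of x multiplied by 2
```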
    Bugfixes / backward incompatibilities

    • eqx.nn.GRUCell will now use its bias term. (Previously it was never adding this.)
    • eqx.filter_eval_shape will no longer promote array-likes to arrays, in either its input or its output.
    • eqx.tree_equal now treats JAX arrays and NumPy arrays as equal.

    Misc

    • Improved compilation speed of eqx.filter_vmap.

    New Contributors

    • @jondeaton made their first contribution in https://github.com/patrick-kidger/equinox/pull/204
    • @IsaacBreen made their first contribution in https://github.com/patrick-kidger/equinox/pull/215

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.8.0...v0.9.0

  • v0.8.0(Sep 22, 2022)

    The ongoing march of small tweaks progresses.

    Main changes this release:

    • eqx.{is_array,is_inexact_array} now return True for np.ndarrays rather than False. This is technically a breaking change, hence the new minor version bump. Rationale in #202.
    • We now use jaxtyping. Hurrah!

    Other changes:

    • make sequential module immutable by @jenkspt in https://github.com/patrick-kidger/equinox/pull/195
    • Add support for asymmetric padding in Conv and ConvTransposed. by @Gurvan in https://github.com/patrick-kidger/equinox/pull/197

    New Contributors

    • @Gurvan made their first contribution in https://github.com/patrick-kidger/equinox/pull/197

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.7.1...v0.8.0

  • v0.7.1(Sep 6, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Fixed NotImplementedError when computing gradients of stateful models by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/191
    • fix attention with mask and add tests by @uuirs in https://github.com/patrick-kidger/equinox/pull/190

    New Contributors

    • @uuirs made their first contribution in https://github.com/patrick-kidger/equinox/pull/190

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.7.0...v0.7.1

  • v0.7.0(Aug 30, 2022)

    • Multiple bugfixes for differentiating through, and serialising, eqx.experimental.BatchNorm.
      • This is the reason for the version bump: if you are using eqx.experimental.{BatchNorm,SpectralNorm,StateIndex} then the serialisation format has changed.
    • Feature: use_ceil added to all pooling layers.

    Autogenerated release notes as follows:

    What's Changed

    • Add len and iter methods to nn.Sequential by @jenkspt in https://github.com/patrick-kidger/equinox/pull/174
    • Add attention functions and tests by @jenkspt in https://github.com/patrick-kidger/equinox/pull/181
    • Fixed BatchNorm not de/serialising correctly by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/172
    • Ordered tree map by @paganpasta in https://github.com/patrick-kidger/equinox/pull/170
    • added use_ceil to pooling by @paganpasta in https://github.com/patrick-kidger/equinox/pull/176
    • Dev by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/184

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.6.0...v0.7.0

  • v0.6.0(Aug 3, 2022)

    • Refactor: the serialisation format for eqx.experimental.{BatchNorm,SpectralNorm,StateIndex} under eqx.tree_{de,}serialise_leaves has been tweaked slightly to avoid an edge-case crash. [This is the reason for the minor version bump to 0.6.0, as this is technically a (very minor) compatibility break.]
    • Refactor: changed from jax.tree_map to jax.tree_util.tree_map to remove all the deprecation warnings JAX has started giving.
    • Feature: added eqx.nn.Lambda (for use with eqx.nn.Sequential)
    • Feature: added eqx.default_{de,}serialise_filter_spec (for use with eqx.tree_{de,}serialise_leaves).
    • Bugfix: fixed BatchNorm crashing under jax.grad.
    • Documentation: lots of tidy-ups and improvements.

    Autogenerated release notes as follows:

    What's Changed

    • Doc tweak by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/141
    • Fix GroupNorm channels argument and docstring by @jenkspt in https://github.com/patrick-kidger/equinox/pull/148
    • make Sequential indexable and add tests by @jenkspt in https://github.com/patrick-kidger/equinox/pull/153
    • replace tree_* with tree_util.tree_* to avoid jax warning messages by @amir-saadat in https://github.com/patrick-kidger/equinox/pull/156
    • Extend deserial filter by @paganpasta in https://github.com/patrick-kidger/equinox/pull/145
    • added lambda_layer to composites by @paganpasta in https://github.com/patrick-kidger/equinox/pull/158
    • Tweaked docs for Lambda by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/159
    • Tweaked intro docs to improve readability of filtering by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/160
    • Batch norm grad crash fix by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/162
    • added #outputs to the StateIndex example by @paganpasta in https://github.com/patrick-kidger/equinox/pull/164
    • Fixed crash when serialising StateIndices without saved state by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/167
    • v0.6.0 by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/169

    New Contributors

    • @jenkspt made their first contribution in https://github.com/patrick-kidger/equinox/pull/148
    • @amir-saadat made their first contribution in https://github.com/patrick-kidger/equinox/pull/156

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.6...v0.6.0

  • v0.5.6(Jul 20, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Adaptive avg pool 1d by @paganpasta in https://github.com/patrick-kidger/equinox/pull/129
    • {Avg,Max}Pool{1,2,3}D -> {Avg,Max}Pool{1,2,3}d. Removed wrong stride default. by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/135
    • Tweaked AdaptivePool by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/139
    • Adds adaptive pooling by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/140

    New Contributors

    • @paganpasta made their first contribution in https://github.com/patrick-kidger/equinox/pull/129

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.5...v0.5.6

  • v0.5.5(Jul 20, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Fix doc typo by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/130
    • Updated pooling docs with init and call by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/131
    • Doc fix by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/132
    • Tidied helper into a relative import by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/133
    • minor bug fix by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/134

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.4...v0.5.5

  • v0.5.4(Jul 5, 2022)

    • Feature: added equinox.filter_eval_shape.
    • Feature: equinox.filter_pmap(out=...) now supports callable arguments.
    • Upgrade: equinox.filter_vmap(out=...) now properly supports callable arguments. (Previously they were experimental. No part of the API has changed.)
    • Bugfix: passing PyTrees to equinox.filter_{vmap,pmap}(out=...) now works. (Thanks @jatentaki!)

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.3...v0.5.4

  • v0.5.3(Jun 14, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Fix pmap named axes by @jatentaki in https://github.com/patrick-kidger/equinox/pull/113

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.2...v0.5.3

  • v0.5.2(Jun 6, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Fixed GroupNorm raising a spurious runtime error by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/106

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.1...v0.5.2

  • v0.5.1(Jun 6, 2022)

    This release:

    • Adds equinox.nn.GroupNorm.
    • Adds support for grouped convolutions and transposed convolutions, e.g. equinox.nn.Conv2d(..., groups=...). (Thanks @jatentaki!)
    • Fixes exceptions raised by tree_deserialise_leaves having no message.
    • Fixes a few documentation issues. (Thanks @jvmncs!)

    Autogenerated release notes as follows:

    What's Changed

    • Minor doc tweaks for filter_{vmap,pmap} by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/84
    • Updated examples to v0.5.0 by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/86
    • Doc fix for nn.Pool by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/89
    • fix two small typos in documentation by @jvmncs in https://github.com/patrick-kidger/equinox/pull/91
    • Fixed uninformative errors when deserialising by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/97
    • Implement num_feature_groups for Conv by @jatentaki in https://github.com/patrick-kidger/equinox/pull/100
    • Tweak docs by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/101
    • Added GroupNorm by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/104
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/103

    New Contributors

    • @jvmncs made their first contribution in https://github.com/patrick-kidger/equinox/pull/91
    • @jatentaki made their first contribution in https://github.com/patrick-kidger/equinox/pull/100

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.5.0...v0.5.1

  • v0.5.0(May 6, 2022)

    This is a big update.

    Exciting new features!

    • Added filter_vmap.

      • This can be used to create ensembles of models.
      • (Closes #65.)
    • Added filter_pmap.

      • (Closes #65.)
    • Added pooling layers:

      • eqx.nn.Pool
      • eqx.nn.AvgPool1d
      • eqx.nn.AvgPool2d
      • eqx.nn.AvgPool3d
      • eqx.nn.MaxPool1d
      • eqx.nn.MaxPool2d
      • eqx.nn.MaxPool3d
      • (Closes #59.)
      • (Thanks to @Benjamin-Walker for implementing this.)
    • Added tree_serialise_leaves and tree_deserialise_leaves.

      • This can be used to save and load models to file.
      • (Closes #46.)
      • (Thanks to @Jaschau for helpful discussions on this.)
    • Added tree_inference, as a convenience for toggling all inference flags through a model.
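The ensemble idea behind filter_vmap can be sketched in plain JAX for the special case where every leaf of the model is an array. This does not use Equinox; init and apply are hypothetical single-model functions. The filtering in filter_vmap is what lets the same pattern apply to Equinox models, whose PyTrees can also contain non-array leaves.

```python
import jax
import jax.numpy as jnp


def init(key):
    """Hypothetical initialiser: one small model as a PyTree of arrays."""
    return {"w": jax.random.normal(key, (3, 2)), "b": jnp.zeros(2)}


def apply(params, x):
    """Hypothetical forward pass for a single model."""
    return jnp.tanh(x @ params["w"] + params["b"])


keys = jax.random.split(jax.random.PRNGKey(0), 5)

# vmap over the keys: every leaf gains a leading ensemble axis of size 5.
ensemble = jax.vmap(init)(keys)

# Map over the ensemble axis of the parameters, sharing one input.
outputs = jax.vmap(apply, in_axes=(0, None))(ensemble, jnp.ones(3))
```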

    Refactoring for nicer APIs

    • filter_{jit,grad,value_and_grad} now have an easier-to-use API for specifying which arguments have what behaviour.

      • Instead of having to specify (args, kwargs) as a single PyTree, you can now specify default, args and kwargs separately. In particular this avoids doing messy stuff like filter_spec=((...), {}) when you had no kwargs.
      • You no longer have to match up the filter specification for args and kwargs against their runtime values. Both the runtime values, and the filter specification, are matched up against the function signature. e.g. you can do filter_jit(lambda x: x, kwargs=dict(x=True))(1), using a keyword argument at JIT-time and a positional argument at call time.
      • Currying is available: both filter_jit(fun) and filter_jit(default=...)(fun) will work.
      • The old API is still available for backward compatibility, of course.
      • (Closes #48.)
    • tree_at can now replace subtrees, and not just leaves.

      • (Closes #47.)
    • filter and partition now support an is_leaf argument.

      • (Closes #68.)
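The idea of matching both the runtime values and the filter specification "against the function signature" can be illustrated with Python's inspect.signature(...).bind. The helper bind_to_signature is hypothetical, and whether filter_jit is implemented this way is an assumption. Binding pairs each value with its specification by parameter name, so one can be passed by keyword and the other positionally.

```python
import inspect


def bind_to_signature(fun, args, kwargs):
    """Hypothetical helper: map positional/keyword arguments onto parameter names."""
    bound = inspect.signature(fun).bind(*args, **kwargs)
    return dict(bound.arguments)


def fun(x):
    return x


spec = bind_to_signature(fun, (), dict(x=True))  # filter spec given by keyword
values = bind_to_signature(fun, (1,), {})  # runtime value given positionally

# Pair each runtime value with its filter specification by parameter name.
paired = {name: (values[name], spec[name]) for name in values}
```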

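A rough picture of what partitioning a PyTree buys you, and why it enables filtered transformations: split the tree into an array half and a non-array half, differentiate only the former, then recombine. partition, combine and is_array here are hypothetical plain-JAX helpers rather than Equinox's functions, and the sketch assumes a recent JAX version that provides jax.Array.

```python
import jax
import jax.numpy as jnp


def partition(tree, filter_fn, is_leaf=None):
    """Split a PyTree into (selected, rest); unselected slots become None."""
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf=is_leaf)
    selected = [leaf if filter_fn(leaf) else None for leaf in leaves]
    rest = [None if filter_fn(leaf) else leaf for leaf in leaves]
    return (
        jax.tree_util.tree_unflatten(treedef, selected),
        jax.tree_util.tree_unflatten(treedef, rest),
    )


def combine(first, second):
    """Merge two partitioned PyTrees, preferring the non-None leaves of `first`."""
    return jax.tree_util.tree_map(
        lambda a, b: b if a is None else a,
        first,
        second,
        is_leaf=lambda node: node is None,
    )


def is_array(x):
    return isinstance(x, jax.Array)


# A "model" mixing an array with non-array leaves (a function and a string).
params = {"w": jnp.array(2.0), "activation": jnp.tanh, "name": "layer"}
diff, static = partition(params, is_array)


def loss(diff, static, x):
    p = combine(diff, static)
    return p["activation"](p["w"] * x)


# Only the array half is differentiated; the None slots stay None.
grads = jax.grad(loss)(diff, static, 1.0)
```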
    Miscellaneous

    • Calling filter_jit(filter_grad(fun)) twice will no longer lead to unnecessary recompilation: the second filter_grad(fun) instance will be a PyTree that looks like the first filter_grad(fun) instance, and thus we won't get any recompilation.
      • This is actually an improvement over standard JAX! See https://github.com/google/jax/discussions/10284.

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.4.0...v0.5.0

  • v0.4.0(Apr 8, 2022)

    A new minor release, as there are a few minor breaking changes:

    • Some of the projections in MultiheadAttention no longer have a bias by default (#60)
    • equinox.experimental.{get_state,set_state,BatchNorm} now raise RuntimeErrors for many things; this is to match a change in how jax.experimental.host_callback raises errors in jaxlib>=0.3.5. (#63)

    Besides this, a couple of more exciting (?) things:

    • equinox.tree_pformat (which is used when printing equinox.Modules) now pretty-prints results much more neatly. (#62)
    • equinox.experimental.{get_state,BatchNorm,SpectralNorm} are now substantially faster when run in inference mode. (#61)

    Both of these sound pretty minor, but both were technically really interesting to implement ;)


    The pull requests in this release were:

    • Linear doc tweaks by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/58
    • Remove default bias in MultiheadAttention by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/60
    • Improvements to stateful by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/61
    • Improved pretty-printing by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/62
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/63

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.3.2...v0.4.0

  • v0.3.2(Mar 31, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Version 0.3.2: SpectralNorm and more! by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/55

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.3.1...v0.3.2

  • v0.3.1(Mar 28, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Fix function names when logging compilation progress by @marcelroed in https://github.com/patrick-kidger/equinox/pull/52

    New Contributors

    • @marcelroed made their first contribution in https://github.com/patrick-kidger/equinox/pull/52

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.3.0...v0.3.1

  • v0.3.0(Mar 27, 2022)

    Three main things in this release.

    1. equinox.experimental.BatchNorm. Hurrah, that's nice to have.
    2. Very interesting from a technical point of view: stateful operations. In this case, equinox.experimental.{get_state, set_state, StateIndex}. These are the technology used to update the statistics of BatchNorm without requiring the user to faff around outputting the model themselves. They work by wrapping jax.experimental.host_callback.call to save and load external state on demand, which is pretty magic, so these should really be used sparingly...
    3. Removed several old pieces of deprecated functionality: equinox.jitf and so on.

    Autogenerated release notes as follows:

    What's Changed

    • Score based example by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/42
    • minor doc fixes by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/43
    • Version 0.3.0 -- BatchNorm and stateful by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/50

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.2.2...v0.3.0

  • v0.2.2(Mar 15, 2022)

    Added several new layers:

    • LayerNorm
    • MultiheadAttention
    • ConvTranspose, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
    • Embedding

    (Thanks to @andyehrenberg for implementing much of this, and to @lucidrains for reviewing the implementation of attention.)

    Autogenerated release notes as follows:

    What's Changed

    • ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm by @andyehrenberg in https://github.com/patrick-kidger/equinox/pull/34
    • Tidied; simplified; generalised ConvTranspose implementation. by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/40
    • Attention, Transposed Convolutions, Embeddings, LayerNorm by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/38

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.2.1...v0.2.2

  • v0.2.1(Mar 11, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Added automated releases. by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/39
    • Added pretty-printing for Modules by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/41

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.2.0...v0.2.1

  • v0.2.0(Mar 6, 2022)

    First release using GitHub releases. We'll be using this to serve as a changelog.

    This bumps the minor version 0.1.6 -> 0.2.0 so this is a breaking release. (Admittedly for something pretty minor. See the autogenerated changelog below.)

    What's Changed

    • Fixed filter_grad(has_aux=True) returning arguments in the wrong order. by @patrick-kidger in https://github.com/patrick-kidger/equinox/pull/36

    Full Changelog: https://github.com/patrick-kidger/equinox/compare/v0.1.6...v0.2.0

Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)