Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.
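A rough sketch of what that unified treatment means, in plain Python rather than Diffrax code: write every equation as dy = f(t, y) dx(t) for some control x, where x(t) = t gives an ODE, a Brownian motion gives an SDE, and an interpolated data path gives a CDE. A single stepping rule then serves all of them. The `Term` class and the helpers below are hypothetical illustrations, not Diffrax's API, and only the ODE and SDE cases are shown.

```python
import math
import random


class Term:
    """Hypothetical term: a vector field vf(t, y) paired with a control x(t).

    Over a step [t0, t1] it contributes vf(t0, y) * (x(t1) - x(t0)).
    """

    def __init__(self, vf, increment):
        self.vf = vf
        self.increment = increment  # (t0, t1) -> x(t1) - x(t0)


def euler_step(terms, t0, t1, y):
    # One update rule for every kind of term: vector field times control increment.
    return y + sum(term.vf(t0, y) * term.increment(t0, t1) for term in terms)


def solve(terms, t0, t1, y0, n_steps):
    h = (t1 - t0) / n_steps
    y = y0
    for i in range(n_steps):
        y = euler_step(terms, t0 + i * h, t0 + (i + 1) * h, y)
    return y


rng = random.Random(0)
# ODE term: the control is time itself, so the increment is t1 - t0.
drift = Term(lambda t, y: -y, lambda s, t: t - s)
# SDE term: the control is Brownian motion, increments ~ Normal(0, t1 - t0).
noise = Term(lambda t, y: 0.1 * y, lambda s, t: rng.gauss(0.0, math.sqrt(t - s)))

y_ode = solve([drift], 0.0, 1.0, 2.0, n_steps=100)         # Euler
y_sde = solve([drift, noise], 0.0, 1.0, 2.0, n_steps=100)  # Euler-Maruyama
```

Adding the noise term changes nothing in the solver loop: the same Euler update becomes the Euler-Maruyama method.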

Installation

pip install diffrax

Requires Python >=3.7 and JAX >=0.2.27.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.
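As written, the quick example takes fixed steps of dt0=0.1; passing a step-size controller such as PIDController (as several of the issues below do) lets the step adapt. Dopri5 supports this because it is an embedded pair: each step yields a 5th-order and a 4th-order solution, and their difference estimates the local error. The plain-Python sketch below shows the same mechanism with the much simpler Heun-Euler 2(1) pair and a textbook step-size rule; the function name, tolerances and controller constants are illustrative assumptions, not Diffrax internals.

```python
import math


def heun_euler_adaptive(f, t0, t1, y0, dt0, rtol=1e-6, atol=1e-9):
    """Adaptive integration of y' = f(t, y), y a list of floats, using the
    embedded Heun (order 2) / Euler (order 1) pair. Returns (y(t1), n_accepted)."""
    t, y, dt = t0, list(y0), dt0
    n_accepted = 0
    while t < t1:
        last = dt >= t1 - t
        if last:
            dt = t1 - t  # land exactly on t1
        k1 = f(t, y)
        y_low = [yi + dt * a for yi, a in zip(y, k1)]                      # Euler
        k2 = f(t + dt, y_low)
        y_high = [yi + 0.5 * dt * (a + b) for yi, a, b in zip(y, k1, k2)]  # Heun
        # RMS of the error estimate, scaled by the mixed tolerance atol + rtol*|y|.
        err = math.sqrt(sum(
            ((hi - lo) / (atol + rtol * max(abs(yi), abs(hi)))) ** 2
            for yi, hi, lo in zip(y, y_high, y_low)) / len(y))
        if err <= 1.0:  # accept the step and keep the higher-order value
            t = t1 if last else t + dt
            y = y_high
            n_accepted += 1
        # The estimate is O(dt^2), so rescale by err^(-1/2) with a safety
        # factor of 0.9, clipping the change to [0.2, 5].
        factor = 5.0 if err == 0.0 else 0.9 * err ** -0.5
        dt *= min(5.0, max(0.2, factor))
    return y, n_accepted


# Same problem as the quick example: y' = -y, y0 = [2, 3], t in [0, 1].
y1, n_steps = heun_euler_adaptive(lambda t, y: [-v for v in y], 0.0, 1.0, [2.0, 3.0], dt0=0.1)
```

Here the controller rejects the initial dt=0.1 several times before settling on a much smaller step, which is exactly the behaviour a fixed dt0 cannot provide.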

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

Comments
  • [WIP] Delay differential equations

    @thibmonsel

    This is a quick WIP draft of how we might add support for delay diffeqs into Diffrax.

    The goal is to make the API follow:

    def vector_field(t, y, args, *, history):
        ...
    
    delays = [lambda t, y, args: 1.0,
              lambda t, y, args: max(y, 1)]
    
    diffeqsolve(ODETerm(vector_field), ..., delays=delays)
    

    There are several pieces that still need doing:

    • The nonlinear solve, with respect to the dense solution over each step. (E.g. as per Section 4.1 of the DelayDiffEq.jl paper)
    • Detecting discontinuities and stepping to them directly. (Section 4.2)
    • Possibly add special support for "nice" delays, that we might be able to handle more efficiently? E.g. as long as our minimal delay is larger than our step size then the nonlinear solve can be skipped.
    • Adding documentation.
    • Adding an example.
    • Probably now would be a good time to figure out how to add support for solving DAEs as well (e.g. see #62). Both involve a nonlinear solve, and both involve passing extra information to the user-provided vector field. It might be that we can make use of the same mechanisms for both. (And at the very least we should ensure that any choices we make now don't negatively impact DAE support later.)
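To make the `history` idea and the "nice delays" point concrete, here is a plain-Python sketch (not Diffrax code; the function name and the fixed-step Euler scheme are illustrative assumptions). With a constant delay that is a whole multiple of the step size, the delayed value t - tau always falls on an already-computed grid point, so no nonlinear solve against the dense solution is needed. The grid here also happens to contain t = 1, 2, where the derivative discontinuities propagated from t = 0 occur.

```python
def euler_constant_delay(f, history, tau, t0, t1, n_steps_per_delay):
    """Explicit Euler for y'(t) = f(t, y(t), y(t - tau)).

    history(t) gives y for t <= t0. The step is h = tau / n_steps_per_delay,
    so t - tau always lands on a previously computed grid point (or in the history).
    """
    h = tau / n_steps_per_delay
    n_total = round((t1 - t0) / h)
    ys = [history(t0)]
    for n in range(n_total):
        t = t0 + n * h
        lag = n - n_steps_per_delay  # grid index of t - tau
        y_delayed = ys[lag] if lag >= 0 else history(t - tau)
        ys.append(ys[-1] + h * f(t, ys[-1], y_delayed))
    return ys


# y'(t) = -y(t - 1) with y = 1 for t <= 0, integrated on [0, 2].
ys = euler_constant_delay(lambda t, y, y_lag: -y_lag, lambda t: 1.0,
                          tau=1.0, t0=0.0, t1=2.0, n_steps_per_delay=100)
```

On [0, 1] the exact solution is y = 1 - t, and on [1, 2] it is y = -(2t - t^2/2 - 3/2), so y(2) = -0.5; the Euler value at t = 2 is about -0.505.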
    opened by patrick-kidger 24
  • Can't return solution of coupled differential equations

    I'm trying to solve a mid-sized system of coupled differential equations with diffrax. I'm using version 0.2.0. Here's a short snippet of dummy code that raises the issue I'm having:

    import jax.numpy as jnp
    from diffrax import diffeqsolve, ODETerm, Kvaerno3,PIDController
    
    def Results():
        def Y_prime(t, Y, args):
            dY = jnp.array([Y[6], (Y[5]-Y[6])**2,Y[0]+Y[7], (Y[1])**2, Y[2],Y[3], Y[4]**3, Y[5]**2])
            return dY
            
        t_init = 100
        t_fin = 1e5
    
        Yn_i = 1e-5
        Yp_i = 1e-6
        Yd_i = 1e-12
        Yt_i = 1e-12
        YHe3_i = 1e-12
        Ya_i = 1e-12
        YLi7_i = 1e-12
        YBe7_i = 1e-12
    
        Y0=jnp.array([[Yn_i], [Yp_i], [Yd_i], [Yt_i], [YHe3_i], [Ya_i], [YLi7_i], [YBe7_i]])
        term = ODETerm(Y_prime)
        solver = Kvaerno3()
        stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
        t_eval = jnp.logspace(jnp.log10(t_init),jnp.log10(t_fin),num=100)
        sol_at_MT = diffeqsolve(term, solver, t0=jnp.float64(t_init), t1=jnp.float64(t_fin), dt0=jnp.float64((t_eval[1]-t_eval[0])/10),y0=Y0,stepsize_controller=stepsize_controller,max_steps=None)
        Yn_MT_f, Yp_MT_f, Yd_MT_f, Yt_MT_f, YHe3_MT_f, Ya_MT_f, YLi7_MT_f, YBe7_MT_f = sol_at_MT.ys[-1][0][0],sol_at_MT.ys[-1][1][0],sol_at_MT.ys[-1][2][0],sol_at_MT.ys[-1][3][0],sol_at_MT.ys[-1][4][0],sol_at_MT.ys[-1][5][0],sol_at_MT.ys[-1][6][0],sol_at_MT.ys[-1][7][0]
    
        Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Yn_MT_f, Yp_MT_f, Yd_MT_f,Yt_MT_f,YHe3_MT_f,Ya_MT_f,YLi7_MT_f, YBe7_MT_f
        return jnp.array([Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f])
    Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Results()
    print(Yn_f)
    

    It seems diffrax successfully solves the differential equation but struggles to return the output, i.e. the code seems to hang when trying to assign values to the variable sol_at_MT. Tinkering a bit with the diffrax source, it looks like there are two things going on.

    One is that, no matter what I try to return (even if I set all of the returns to None), if the lines right before the return in integrate.py

    branched_error_if(
        throw & jnp.invert(is_okay(result)),
        error_index,
        RESULTS.reverse_lookup,
    )
    

    aren't commented out, the code will freeze. I can include a print statement right after these lines (just before the return) that prints successfully even when they're not commented out, but I can't assign anything to sol_at_MT without the code hanging if these lines are left in.

    Then, if I comment that branched_error_if() call out, the code still hangs if I try to return ts, ys, stats or result from integrate.py. This doesn't seem to be an issue of time or memory; the code just freezes up and can't even be aborted from the command line whether I'm running locally or with extra resources on a cluster.

    question 
    opened by cgiovanetti 12
  • Handling discontinuities in time derivative?

    Hi, first of all, let me say that this looks like an amazing project. I am looking forward to playing around with this :).

    In a concrete problem I am dealing with, I have a forced system where the external force is piecewise constant. The external force changes at specific time points (t1, ..., tn), causing a discontinuity of the time derivative.
    I would like to use adaptive step-size solvers for increased accuracy, but naively applying adaptive step-size solvers will "waste" a lot of steps to find the point of change.

    Would including the change points in SaveAt avoid this problem? Or is there some other recommended way to handle this?
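A solver-agnostic workaround is to split the integration at the known change points so that no step ever straddles a jump; with Diffrax this would amount to chaining one solve per segment, each starting from the previous segment's final state. The plain-Python sketch below uses RK4 with a fixed number of steps per segment purely for illustration; the helper names are hypothetical.

```python
def rk4_segment(f, t0, t1, y0, n_steps):
    """Classical RK4 for y' = f(t, y) on [t0, t1] with n_steps equal steps."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y


def solve_piecewise(f_for_segment, breakpoints, y0, n_steps=50):
    """breakpoints = [t0, t1, ..., tn]; f_for_segment(i) returns the smooth
    vector field valid on [breakpoints[i], breakpoints[i + 1]]."""
    y = y0
    for i, (a, b) in enumerate(zip(breakpoints[:-1], breakpoints[1:])):
        y = rk4_segment(f_for_segment(i), a, b, y, n_steps)  # restart at each jump
    return y


# Forced decay y' = u(t) - y with u = 1 on [0, 1) and u = 0 on [1, 2].
forces = [1.0, 0.0]
y2 = solve_piecewise(lambda i: (lambda t, y: forces[i] - y), [0.0, 1.0, 2.0], y0=0.0)
```

The exact value is y(2) = (1 - e^-1) e^-1, and because each segment is smooth, the solver's accuracy is not degraded by the jump at t = 1.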

    opened by jaschau 12
  • Slow `jit` compilation time compared to `jax.experimental.ode.odeint`

    hi @patrick-kidger, big fan of diffrax!

    I've been playing around with some of the functionality in this repository and comparing it with the ODE solver in JAX. The one pain point I noticed is that JIT compilation seems relatively slow, particularly when I try to jit the grad of a simple loss function containing diffeqsolve. I was wondering whether this is an error on my part (perhaps I botched the diffrax implementation) or whether some optimization has yet to be done. The demonstration is below:

    from jax.config import config
    config.update("jax_enable_x64", True)
    config.update("jax_debug_nans", True) 
    config.parse_flags_with_absl()
    import jax
    import jax.numpy as jnp
    from jax import random
    import numpy as np
    from functools import partial
    import haiku as hk
    
    def exact_kinematic_aug_diff_f(t, y, args_tuple):
        """
        """
        _y, _, _ = y
        _params, _key, diff_f = args_tuple
        aug_diff_fn = lambda __y : diff_f(t, __y, (_params,))
        _f, s, t = aug_diff_fn(_y)
        r = jnp.sum(t)
        return _f, r, 0.
    
    def exact_kinematic_odeint_diff_f(y, t, params, canonical_diff_fn):
        run_y = y[0]
        _f, s, t = canonical_diff_fn(t, run_y, (params,))
        return _f, jnp.sum(s), 0.
    
    class TestMLP(hk.Module):
        def __init__(self, num_particles, name=None):
            super().__init__(name=None)
            self._mlp = hk.nets.MLP([8,8,8,8,num_particles*12])
            self._num_particles=num_particles
        def __call__(self, t, y):
            in_y = (y + t).flatten()
            outter = self._mlp(in_y).reshape((4, self._num_particles, 3))
            return outter[:2], outter[2], outter[3]
    
    def test(num_particles):
        import functools
        from jax.experimental.ode import odeint
        import diffrax
        
        #generate positions/velocities
        small_positions = jax.random.normal(jax.random.PRNGKey(261), shape=(num_particles,3))
        small_velocities = jax.random.normal(jax.random.PRNGKey(235), shape=(num_particles,3))
        small_positions_and_velocities = jnp.vstack([small_positions[jnp.newaxis, ...], small_velocities[jnp.newaxis, ...]])
        
        # make module kwargs
        VectorMLP_kwargs = {'num_particles': num_particles}
        
        # make module function
        def _diff_f_wrapper(t, y):
            diff_f = TestMLP(**VectorMLP_kwargs)
            return diff_f(t, y)
        
        diff_f_init, diff_f_apply = hk.without_apply_rng(hk.transform(_diff_f_wrapper))
        init_params = diff_f_init(jax.random.PRNGKey(36), 0., small_positions_and_velocities)
        canonicalized_diff_f_fn = lambda _t, _y, _args_tuple : diff_f_apply(_args_tuple[0], _t, _y)
        
        # make the augmented functions
        odeint_aug_diff_func = functools.partial(exact_kinematic_odeint_diff_f, canonical_diff_fn=canonicalized_diff_f_fn)
        diffeqsolve_aug_diff_func = exact_kinematic_aug_diff_f
        
        # odeint solver
        def odeint_solver(_parameters, _init_y, _key):
            aug_init_y = (_init_y, 0., 0.)
            outs = odeint(odeint_aug_diff_func, aug_init_y, jnp.array([0., 1.]), _parameters, rtol=1.4e-8, atol=1.4e-8)
            final_outs = (outs[0][-1], outs[1][-1], outs[2][-1])
            return final_outs
        
        def diffrax_ode_solver(_parameters, _init_y, _key):
            term=diffrax.ODETerm(diffeqsolve_aug_diff_func)
            stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
            solver = diffrax.Dopri5()
            aug_init_y = (_init_y, 0., 0.)
            sol = diffrax.diffeqsolve(term, 
                                      solver, 
                                      t0=0., 
                                      t1=1., 
                                      dt0=1e-1, 
                                      y0=aug_init_y, 
                                      stepsize_controller=stepsize_controller, 
                                      args=(_parameters, _key, canonicalized_diff_f_fn))
            return sol.ys[0][0], sol.ys[1][0], sol.ys[2][0]
        
        @jax.jit
        def odeint_loss_fn(_params, _init_y, _key):
            ode_solution = odeint_solver(_params, _init_y, _key)
            return jnp.sum(ode_solution[1]**2)
        
        @jax.jit
        def diffrax_loss_fn(_params, _init_y, _key):
            ode_solution = diffrax_ode_solver(_params, _init_y, _key)
            return jnp.sum((ode_solution[1])**2)
        
        # test
        import time
        
        # odeint compilation time
        start_time = time.time()
        _ = jax.grad(odeint_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"odeint comp. time: {end_time - start_time}")
        
        # diffrax compilation time
        start_time = time.time()
        _ = jax.grad(diffrax_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"diffrax comp. time: {end_time - start_time}")
    
    

    running test(8) gives me the following compilation time on CPU:

    odeint comp. time: 2.5580570697784424
    diffrax comp. time: 23.965799570083618
    

    I noticed that if I use diffrax.BacksolveAdjoint, compilation time goes down to ~8 seconds, but I'm keen to avoid that method based on your docs; also, the compilation time in diffrax seems heavily dependent on the number of hidden layers in TestMLP, perhaps suggesting that diffrax compiles for loops suboptimally? Thanks!
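Each figure above comes from a single first call, so it mixes tracing and compilation with execution. A small timing helper (a sketch; the name is hypothetical) that separates the first call from later ones makes that split visible. To benefit from the compilation cache, time a single function object created once, e.g. fn = jax.jit(jax.grad(loss_fn)).

```python
import time


def time_first_and_steady(fn, *args, repeats=5):
    """Time fn(*args): returns (first_call_seconds, best_of_later_calls_seconds).

    For a jax.jit-compiled fn the first call also pays for tracing and
    compilation. JAX dispatches work asynchronously, so every array leaf of the
    output is synchronised with block_until_ready() before the clock stops.
    """
    try:
        from jax.tree_util import tree_leaves
    except ImportError:  # lets the helper run without JAX installed
        def tree_leaves(x):
            return [x]

    def call():
        out = fn(*args)
        for leaf in tree_leaves(out):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()

    start = time.perf_counter()
    call()
    first = time.perf_counter() - start

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        call()
        best = min(best, time.perf_counter() - start)
    return first, best


# Any callable works; with JAX, pass e.g. jax.jit(jax.grad(loss_fn)) and its inputs.
first, steady = time_first_and_steady(sorted, list(range(10_000, 0, -1)))
```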

    refactor next 
    opened by dominicrufa 11
  • No GPU/TPU found, falling back to CPU

    Here's the full warning that I get (I do have a GPU):

    >>> import diffrax
    2022-03-24 16:30:19.350737: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:170] XLA service 0x55795c0d4670 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
    2022-03-24 16:30:19.350761: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:178]   StreamExecutor device (0): Interpreter, <undefined>
    2022-03-24 16:30:19.353414: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:169] TfrtCpuClient created.
    2022-03-24 16:30:19.353886: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    

    Edit: I installed diffrax from conda-forge.

    opened by ma-sadeghi 11
  • Logging metrics during an ODE solve

    Hello @patrick-kidger,

    thank you for open-sourcing this nice library! I was going to resume work on my own small ODE lib, but since this is much more elaborate than what I came up with so far, I am inclined to use this instead for a small project in the future.

    One question that came up to me when reading the source code: Is there currently a way to compute step-wise metrics during the solve? (Think logging step sizes, Jacobian eigenvalues, etc.)

    This would presumably happen in the integrate method. Could I e.g. use the solver_state pytree for this in, say, overridden solver classes? Thank you for your consideration.
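Independently of how Diffrax itself might expose this, one constraint shapes any approach: inside a jitted loop such as jax.lax.while_loop, the loop state must keep a fixed structure and fixed array shapes, so a Python list cannot grow step by step. A common pattern is to preallocate a buffer of length max_steps and write row n at step n. The NumPy sketch below shows that pattern with a fixed-step Euler loop; the function and the chosen metrics are illustrative assumptions, and a JAX version would use lax.while_loop with log.at[n].set(...).

```python
import numpy as np


def euler_with_log(f, t0, y0, dt, n_steps):
    """Fixed-step Euler for scalar y' = f(t, y) that records per-step metrics.

    The loop state (step index, t, y, log) has a fixed structure and shape,
    the form jax.lax.while_loop requires; the log is preallocated and row n
    is written at step n.
    """
    log = np.full((n_steps, 3), np.nan)  # columns: t, dt, |f(t, y)|

    def body(state):
        n, t, y, log = state
        dy = f(t, y)
        log = log.copy()  # keep the body free of side effects on its input
        log[n] = (t, dt, abs(dy))
        return n + 1, t + dt, y + dt * dy, log

    state = (0, t0, y0, log)
    while state[0] < n_steps:
        state = body(state)
    _, t, y, log = state
    return t, y, log


t_end, y_end, log = euler_with_log(lambda t, y: -y, 0.0, 1.0, dt=0.1, n_steps=10)
```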

    opened by nicholasjng 11
  • Brownian motion classes accept pytrees for shape and dtype arguments

    This changes the argument shape for classes VirtualBrownianTree and UnsafeBrownianPath, and adds an additional argument dtype as per the discussion in #180.

    • I decided upon shape: PyTree[Tuple[int, ...]] instead of shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]]. It's unclear what to do with named_shape in jax.ShapeDtypeStruct -- I don't know if there is a way to sample Brownian motion via named shapes. But if you feel strongly about this and give me some pointers, I can reimplement.
    • To allow specifying dtypes, dtype argument specifies them as a pytree and has to be a prefix tree of shape.
    • I added __init__ methods to both classes since I was not sure how to have dtype=None without it.
    • Added some helper functions that I use in misc.py, hope that's the right location to place them.
    • Used jtu.tree_map instead of jax.vmap -- was not sure how to supply is_leaf to jax.vmap. Happy to change this as well, with some pointers.
    • Changed the test_brownian.py:test_shape to test pytree shapes and dtypes. Just noticed that formatting made it look pretty bad, not sure if that's a big deal.
    • Tests pass locally.
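Why is_leaf comes up: in JAX a tuple is itself a pytree node, so a shape such as (2, 3) inside a pytree would be traversed into the integers 2 and 3 unless an is_leaf predicate (which jax.tree_util.tree_map accepts) marks shape tuples as leaves. The dependency-free sketch below mimics that idea and the "dtype is a prefix tree of shape" rule, then samples Brownian increments W(t1) - W(t0) ~ Normal(0, t1 - t0) per leaf. `map_shapes` and `brownian_increments` are hypothetical helpers, not the code from this PR.

```python
import numpy as np


def is_shape(x):
    # A tuple of ints (including the empty tuple) counts as one shape leaf.
    return isinstance(x, tuple) and all(isinstance(d, int) for d in x)


def map_shapes(fn, shapes, dtypes):
    """Apply fn(shape, dtype) to every shape leaf of a dict/list/tuple tree.

    `dtypes` must be a prefix of `shapes`: a single dtype at some node
    applies to every shape below it.
    """
    if is_shape(shapes):
        return fn(shapes, dtypes)
    if isinstance(shapes, dict):
        pick = (lambda k: dtypes[k]) if isinstance(dtypes, dict) else (lambda k: dtypes)
        return {k: map_shapes(fn, v, pick(k)) for k, v in shapes.items()}
    if isinstance(shapes, (list, tuple)):
        branch = isinstance(dtypes, (list, tuple))
        return type(shapes)(
            map_shapes(fn, s, dtypes[i] if branch else dtypes) for i, s in enumerate(shapes))
    raise TypeError(f"unsupported node: {shapes!r}")


def brownian_increments(shapes, dtypes, t0, t1, seed=0):
    """Sample W(t1) - W(t0) ~ Normal(0, t1 - t0) for every shape leaf."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(t1 - t0)
    return map_shapes(
        lambda shape, dtype: np.asarray(std * rng.standard_normal(shape)).astype(dtype),
        shapes, dtypes)


incs = brownian_increments({"a": (2, 3), "b": [(), (4,)]},
                           {"a": np.float32, "b": np.float64}, t0=0.0, t1=0.5)
```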

    Let me know what you think. Thanks!

    opened by ciupakabra 9
  • added new kalman-filter example

    I wrote a little additional example that showcases diffrax in a maybe not so obvious way. It also showcases equinox and the ability to freeze parameters. Let me know what you think (and what needs to be changed). Greetings

    opened by SimiPixel 8
  • Performance against `jax.experimental.ode.odeint`

    Hi @patrick-kidger, I was excited to test out Diffrax in our code. However, we found it did not perform as well as expected. This is likely due to nuances on our end, but because of https://github.com/google/jax/issues/9654 I thought I would post an MWE.

    import diffrax
    import jax
    import jax.experimental.ode  # needed for jax.experimental.ode.odeint below
    import ticktack
    
    PARAMS = (774.86, 0.25, 0.8, 6.44)
    
    STEADY_PROD = 1.8803862513018528
    
    STEADY_STATE = jax.numpy.array(
        [1.34432991e+02, 7.07000000e+02, 1.18701144e+03,
        3.95666872e+00, 4.49574232e+04, 1.55056740e+02,
        6.32017337e+02, 4.22182768e+02, 1.80125397e+03,
        6.63307283e+02, 7.28080320e+03], 
        dtype=jax.numpy.float64)
    
    PROD_COEFFS = jax.numpy.array(
        [0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
        dtype=jax.numpy.float64)
    
    MATRIX = jax.numpy.array([
        [-0.509, 0.009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.508, -0.44, 0.068, 0.0, 0.0, 0.545, 0.0, 0.167, 0.002, 0.002, 0.0],
        [0.0, 0.121, -0.155, 12.0, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0],
        [0.0, 0.0, 4.4000e-02, -1.3333e+01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.042, 1.333, -0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.229, 0.0, 0.0, 0.0, -1.046, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.136, -0.033, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.033, -0.183, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, -0.002, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, -0.002, 0.0],
        [0.0, 0.0, 3.333e-04, 0.0, 5.291e-06, 0.0, 0.0, 0.0, 0.0, 4.0e-04, -1.2340e-04]], 
        dtype=jax.numpy.float64)
    
    @jax.jit 
    def driving_term(t, args):
        start_time, duration, phase, area = jax.numpy.array(args)
        middle = start_time + duration / 2.
        height = area / duration
    
        gauss = height * \
            jax.numpy.exp(- ((t - middle) / (0.5 * duration)) ** 16.)
        sine = STEADY_PROD + 0.18 * STEADY_PROD *\
            jax.numpy.sin(2 * jax.numpy.pi / 11 * t + phase * 2 * jax.numpy.pi / 11)
    
        return (sine + gauss) * 3.747
    
    @jax.jit
    def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term, 
                       prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    @jax.jit
    def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term, 
                     prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    time_out = jax.numpy.linspace(750, 800, 1000)
    
    %%timeit
    jax.experimental.ode.odeint(jax_dydt, STEADY_STATE, time_out, PARAMS)
    
    term = diffrax.ODETerm(diffrax_dydt)
    solver = diffrax.Bosh3()
    step_size = diffrax.PIDController(rtol=1e-10, atol=1e-10)
    save_time = diffrax.SaveAt(ts=time_out)
    
    %%timeit
    diffrax.diffeqsolve(args=PARAMS, terms=term, solver=solver, y0=STEADY_STATE,
                        t0=time_out.min(), t1=time_out.max(), dt0=0.01,
                        saveat=save_time, stepsize_controller=step_size, 
                        max_steps=10000)
    

    Sorry that the example is so long, but I wanted to keep it very similar to our code.

    Thanks in advance.

    Jordan
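The snippet defines the same vector field twice because odeint calls func(y, t, *args) while Diffrax's ODETerm calls vf(t, y, args). A small adapter (a hypothetical helper, shown as a sketch) avoids the duplication, provided Diffrax is given args as the tuple of extra arguments odeint would have received, e.g. ODETerm(odeint_to_diffrax(jax_dydt)) with args=(PARAMS,). Separately, the two timed calls are not like-for-like: odeint uses an adaptive Dormand-Prince method with default rtol=atol=1.4e-8, whereas the Diffrax call uses Bosh3 at rtol=atol=1e-10, so a closer comparison would match solver and tolerances (e.g. Dopri5 with PIDController(rtol=1.4e-8, atol=1.4e-8)).

```python
def odeint_to_diffrax(fn):
    """Adapt fn(y, t, *args) (jax.experimental.ode.odeint order) to
    vf(t, y, args) (Diffrax ODETerm order)."""
    def vector_field(t, y, args):
        # args is the tuple of extra positional arguments odeint would have passed.
        return fn(y, t, *args)
    return vector_field


def toy(y, t, p):  # odeint-style signature
    return p * y + t


vf = odeint_to_diffrax(toy)
value = vf(2.0, 3.0, (5.0,))  # same as toy(3.0, 2.0, 5.0)
```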

    opened by Jordan-Dennis 8
  • Weird behaviour due to defaults when using Implicit-Euler

    When using dfx.ImplicitEuler() with everything else left at its defaults, an error is raised:

    missing rtol and atol of NewtonNonlinearSolver

    You are then prompted to set these values in the step-size controller, because by default the nonlinear solver is supposed to fall back to the values provided to PIDController. But dfx.ImplicitEuler() does not support adaptive step sizing using a PIDController.

    The solution is to use

    solver=dfx.ImplicitEuler(nonlinear_solver=dfx.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    

    Just something that feels a bit odd.
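For context, the rtol and atol here belong to the Newton iteration inside each implicit step, not only to step-size control. A scalar plain-Python sketch of one implicit Euler step follows; the stopping rule |delta| <= atol + rtol*|z| is a common choice and an assumption here, not necessarily NewtonNonlinearSolver's exact criterion.

```python
def implicit_euler_step(f, dfdy, t, y, h, rtol=1e-3, atol=1e-6, max_iters=20):
    """One implicit Euler step for scalar y' = f(t, y).

    Solves g(z) = z - y - h * f(t + h, z) = 0 by Newton's method and returns
    (z, converged).
    """
    z = y  # initial guess: the previous value
    for _ in range(max_iters):
        g = z - y - h * f(t + h, z)
        dg = 1.0 - h * dfdy(t + h, z)
        delta = g / dg
        z -= delta
        if abs(delta) <= atol + rtol * abs(z):  # tolerance-based stopping rule
            return z, True
    return z, False


# y' = -y^2 from y = 1 with h = 0.1; the exact implicit Euler update solves
# z + h z^2 = y, i.e. z = (-1 + sqrt(1 + 4 h y)) / (2 h).
z, converged = implicit_euler_step(lambda t, y: -y * y, lambda t, y: -2.0 * y,
                                   t=0.0, y=1.0, h=0.1)
```

Newton converges quadratically, so even the loose defaults shown (rtol=1e-3, atol=1e-6) leave the step's algebraic error far below the stopping threshold in this example.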

    refactor 
    opened by SimiPixel 6
  • Transform Feedforward-Network + solver into a Recurrent-Network

    Hello Patrick,

    let me first quickly motivate my feature request. As a side project I am currently working on model-based optimal control. For example, in an only partially observable environment, stateful agents are useful. So, suppose the action selection of an agent is given by the following method:

    def select_action(params, state, observation, time):
        apply = neural_network.apply
        state, action = apply(params, state, observation, time)
        return state, action
    
    while True:
        action = select_action(..., observation, env.time)
        observation = env.step(action)
    

    Typically, the apply function is some recurrent neural network. Suppose the environment env is differentiable, because it is just some model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.

    I was wondering whether constructing the following is possible and sensible, i.e. transforming a choice of feedforward network + solver into a recurrent network:

    def select_action(params, ode_state, observation, time):
        rhs = lambda x,u: neural_network.apply(params, x, u)
        solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
        return ode_state, solution.x(time)
    

    I would like to emphasize that this select_action must remain differentiable: the x-output must be differentiable w.r.t. the network parameters.

    I would love to hear your input :) Anyway, thank you in advance.
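One way to read the request: the recurrent state is the ODE state plus the time it was last advanced to, and each call integrates it up to the new time with the latest observation held fixed. The plain-Python sketch below uses a toy linear vector field in place of the neural network; `ode_cell`, the RK4 integrator and the zero-order hold on the input are all assumptions. Written with jax.numpy, the same function would remain differentiable with respect to params, since it only composes differentiable operations.

```python
def ode_cell(params, state, observation, time, n_steps=20):
    """Advance state = (t_prev, x) to `time` under dx/dt = rhs(x), holding the
    observation constant over the interval. Returns (new_state, output)."""
    t_prev, x = state
    a, b = params  # toy "network": dx/dt = a * x + b * u

    def rhs(x):
        return a * x + b * observation

    h = (time - t_prev) / n_steps
    for _ in range(n_steps):  # classical RK4 for the autonomous rhs
        k1 = rhs(x)
        k2 = rhs(x + h / 2 * k1)
        k3 = rhs(x + h / 2 * k2)
        k4 = rhs(x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return (time, x), x  # output: x(time)


params = (-1.0, 1.0)  # dx/dt = -x + u
state = (0.0, 0.0)    # (time, x)
state, out1 = ode_cell(params, state, observation=1.0, time=0.5)
state, out2 = ode_cell(params, state, observation=1.0, time=1.0)
```

With u = 1 and x(0) = 0 the exact solution is x(t) = 1 - e^-t, so two calls chained through the carried state reproduce a single solve over [0, 1].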

    opened by SimiPixel 5
  • ODE solver fails with 'The maximum number of solver steps was reached. Try increasing `max_steps`'

    Hi,

    I was playing with this cool package on a chemical reaction ODE problem. This problem solves for the time evolution of seven chemical concentrations; it is stiff but can be solved using a Fortran-based solver. However, the diffrax version fails with an XlaRuntimeError complaining 'The maximum number of solver steps was reached. Try increasing max_steps'. Unfortunately, the error persists no matter how large max_steps is and whichever solver is used (e.g., ImplicitEuler or Kvaerno5). Note that when commenting out the error check in the diffeqsolve function, I find that the code can solve roughly the first 100 s and then outputs inf (in solution.ys) at later times.

    Any suggestion would be appreciated!

    Below is the code snippet --

    from diffrax import diffeqsolve, ODETerm, SaveAt
    from diffrax import NewtonNonlinearSolver, Dopri5, Kvaerno3, ImplicitEuler, Euler, Kvaerno5
    from diffrax import PIDController
    
    import jax
    import jax.numpy as jnp
    import jax.random as jrandom
    
    from jax.config import config
    config.update("jax_enable_x64", True)
    
    def funclog2(t, logy, args):
        k1, k2, k3 = args[0], args[1], args[2]
        kd1, kd2, kd3 = args[3], args[4], args[5]
        ka1, ka2, ka3 = args[6], args[7], args[8]
        r4 = args[9]
        
        y = jnp.power(10, logy)
        doc, o2, no3, no2, n2, co2, bm = y
        
        # log transform scale
        scale = 1 / jnp.log(10)
        scale = scale / y
        
        # The stoichiometry matrix
        stoich = jnp.array([
            [-1, -1, -1, 5],
            [0, 0, -1, 0],
            [-2, 0, 0, 0],
            [1, -1, 0, 0],
            [0, 1, 0, 0],
            [1, 1, 1, 0],
            [0, 0, 0, -1]
        ])
        
        # Scale stoich
        stoich = jax.vmap(lambda a, b: a*b, in_axes=0)(scale, stoich)
        
        # Reaction rate
        r1 = k1 * bm * doc/(doc+kd1) * no3/(no3+ka1)
        r2 = k2 * bm * doc/(doc+kd2) * no2/(no2+ka2)
        r3 = k3 * bm * doc/(doc+kd3) * o2/(no2+ka3)
        
        r = jnp.array([r1, r2, r3, r4]).T
        
        return stoich @ r
    
    # Static parameters
    k1, k2, k3 = 3.24e-4, 2.69e-4, 9e-4 # [mol/L/sec/mass [BM]]
    kd1, kd2, kd3 = 2.5e-4, 2.5e-4, 2.5e-4 # [mol/L]
    ka1, ka2, ka3 = 1e-6, 4e-6, 1e-6  # [mol/L]
    r4 = 2.8e-6 # [mol/L/sec]
    args = jnp.array([k1, k2, k3, kd1, kd2, kd3, ka1, ka2, ka3, r4])
    
    # The initial concentrations with the following order [mol/L]:
    # doc, o2, no3, no2, n2, co2, bm
    # y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-10, 1e-10, 0.00248, 0.0003])
    y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-3, 1e-3, 0.00248, 0.0003])
    logy0 = jnp.log10(y0)
    
    term = ODETerm(funclog2)
    # solver = Dopri5()
    # solver = Euler()
    solver = Kvaerno5(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # solver = ImplicitEuler(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # t0, t1, dt0 = 0, 3600*24*30, 1
    t0, t1, dt0 = 0, 200, 0.01
    # t0, t1, dt0 = 0, 3600*24, 3600
    solution = diffeqsolve(term, solver, t0=t0, t1=t1, dt0=dt0, max_steps=400000,
                           stepsize_controller=PIDController(rtol=1e-3, atol=1e-6),
                           saveat = SaveAt(t0=True, ts=jnp.linspace(t0,t1)), 
                           y0=logy0, args=args)
    solution.stats
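The snippet integrates the base-10 logarithms of the concentrations, and the `scale` factor in `funclog2` is the chain rule for that change of variables. With S the 7×4 `stoich` matrix and r = (r1, r2, r3, r4) the reaction rates:

```latex
z_i = \log_{10} y_i
\quad\Longrightarrow\quad
\frac{dz_i}{dt} = \frac{1}{y_i \ln 10}\,\frac{dy_i}{dt}
= \frac{1}{y_i \ln 10}\sum_{j=1}^{4} S_{ij}\, r_j ,
\qquad i = 1, \dots, 7 .
```

The line `jax.vmap(lambda a, b: a*b, in_axes=0)(scale, stoich)` performs exactly this row scaling by 1/(y_i ln 10); it is equivalent to `scale[:, None] * stoich`.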
    
    question 
    opened by PeishiJiang 5
  • Truncated Back Propagation through time

    Hi, I was wondering if it is possible to integrate truncated backpropagation through time (TBPTT) into Diffrax. I couldn't find any options for this in Diffrax or Equinox, nor could I find any implementation of TBPTT in the source code in integrate.py, but maybe I missed it. My best guess would be to write a custom adjoint class that implements TBPTT, but I am not sure how to do this. My question is: would it be possible to (easily) implement TBPTT to train my NDEs, and how should I approach this?
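One common way to get TBPTT in JAX, independent of any Diffrax feature, is to cut the gradient every K steps with jax.lax.stop_gradient, so backpropagation only reaches back to the start of the current window. The sketch below does this for an explicit-Euler rollout of dy/dt = -theta*y written as a plain Python loop. With Diffrax, the analogous idea would be to solve window by window and stop the gradient on the state passed between solves; that is an assumption about usage, not a documented recipe.

```python
import jax
import jax.numpy as jnp


def rollout_loss(theta, y0, n_steps=20, h=0.1, window=None):
    """Explicit-Euler rollout of dy/dt = -theta * y with loss sum_n y_n^2.

    With window=K, jax.lax.stop_gradient is applied to the state every K steps,
    so gradients only flow back within each window of K steps (TBPTT).
    """
    y = jnp.asarray(y0)
    loss = 0.0
    for n in range(n_steps):
        if window is not None and n % window == 0:
            y = jax.lax.stop_gradient(y)  # truncate backpropagation here
        y = y + h * (-theta * y)
        loss = loss + y ** 2
    return loss


theta, y0 = 1.5, 1.0
g_full = jax.grad(rollout_loss)(theta, y0)              # full BPTT
g_tbptt = jax.grad(rollout_loss)(theta, y0, window=5)   # truncated, K = 5
```

The forward values are identical in both cases; only the gradient changes, because contributions that cross a window boundary are dropped.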

    feature 
    opened by sdevries0 1
  • Fastest way to evaluate a solution

    Hi, suppose I have a simple ODE that I solve with diffrax. What would be the fastest way to use the solution in another piece of code? I need to evaluate the solution at points not known in advance. I thought of generating a dense solution sol and then calling sol.evaluate() on my points of interest every time I need it (using vmap when needed). Is this the most efficient way, or should I interpolate a fixed-grid solution myself and create a jitted function that evaluates it at my points of interest?
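As an illustration of the second option, here is a NumPy sketch (the helper name is hypothetical) that stores values and derivatives on a fixed grid and evaluates a piecewise cubic Hermite interpolant at arbitrary query points in one vectorised call. The same code should carry over to jax.numpy, which makes it jit- and vmap-friendly. Whether it beats sol.evaluate() depends on the problem and is worth measuring; its accuracy is that of the interpolant on the chosen grid.

```python
import numpy as np


def make_hermite_evaluator(ts, ys, fs):
    """Return evaluate(t) for the cubic Hermite interpolant through grid values
    ys and derivatives fs (fs = f(t, y) from the vector field) at times ts."""
    ts, ys, fs = (np.asarray(a, dtype=float) for a in (ts, ys, fs))

    def evaluate(t):
        t = np.asarray(t, dtype=float)
        # Index of the interval [ts[i], ts[i+1]] containing each query point.
        i = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2)
        h = ts[i + 1] - ts[i]
        s = (t - ts[i]) / h
        h00 = 2 * s**3 - 3 * s**2 + 1
        h10 = s**3 - 2 * s**2 + s
        h01 = -2 * s**3 + 3 * s**2
        h11 = s**3 - s**2
        return h00 * ys[i] + h10 * h * fs[i] + h01 * ys[i + 1] + h11 * h * fs[i + 1]

    return evaluate


# Grid solution of y' = -y, y(0) = 1 (here the exact values stand in for solver output).
ts = np.linspace(0.0, 2.0, 21)
ys = np.exp(-ts)
evaluate = make_hermite_evaluator(ts, ys, -ys)
query = np.array([0.05, 0.731, 1.234, 2.0])
approx = evaluate(query)
```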

    question 
    opened by marcofrancis 1
  • Make diffeqsolve convertible to TensorFlow

    Based on a talk on NODEs on YouTube I came across this package, and it looks perfect for a project we are planning (thanks for the great talk!). Now, one of the platforms where we want to run our code does not support JAX/XLA/TensorFlow, just ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because the unsupported unvmap_any is used (at compile time!) to deduce the number of iterations needed.

    Minimal example:

    import tensorflow as tf
    import jax.numpy as jnp
    import tf2onnx
    
    from diffrax import diffeqsolve, ODETerm, Euler
    from jax.experimental import jax2tf
    
    def simulate(y0):
        solution = diffeqsolve(
                terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
                t0=0, t1=1, dt0=0.1, y0=y0)
        return solution.ys[0]
    
    # This works
    x = simulate(100)
    assert jnp.isclose(x, jnp.exp(-1)*100, atol=.1, rtol=.1)
    
    simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))
    
    # Does not work:
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    
    # Also does not work:
    tf2onnx.convert.from_function(
            simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    

    For us, it would be really nice to use a GPU/TSP during training with JAX, then transfer to this specific piece of hardware with just ONNX support for inference (at that point I don't need gradient calculation anymore). Of course, solving this might be completely outside the scope of the project, and there are other solutions like writing the solvers from scratch or using existing solvers in TF/PyTorch.

    Currently my knowledge of JAX is limited (hopefully this will soon improve!). If this is the only function stopping Diffrax from being TensorFlow-convertible, maybe a small workaround is possible. I'm also happy with an answer like 'no, we don't do this' or 'send us a PR if you want to have this fixed'.

    feature 
    opened by llandsmeer 6
  • Question about BacksolveAdjoint through SemiImplicitEuler solver

    I am testing the adjoint method to calculate the gradients from a SemiImplicitEuler solver. I get errors when calculating the gradients using the BacksolveAdjoint method. Here is a minimal example. It would be great to have some suggestions.

    Thank you in advance!

    from diffrax import diffeqsolve, ODETerm, SemiImplicitEuler, SaveAt, BacksolveAdjoint
    import jax.numpy as jnp
    from jax import grad
    from matplotlib import pyplot as plt

    def drdt(t, v, args):
        return v

    def dvdt(t, r, args):
        return -args[0]*(r-args[1])

    terms = (ODETerm(drdt), ODETerm(dvdt))
    solver = SemiImplicitEuler()
    y0 = (jnp.array([1.0]), jnp.array([0.0]))
    saveat = SaveAt(ts=jnp.arange(0, 30, 0.1))

    def loss(y0):
        solution = diffeqsolve(terms, solver, t0=0, t1=30, dt0=0.0001, y0=y0, args=[1.0, 0.0], saveat=saveat, max_steps=10000000, adjoint=BacksolveAdjoint())
        return jnp.sum(solution.ys[0])

    grads = grad(loss)(y0)
    print(grads)

    here is the error message:

    Traceback (most recent call last):
      File "test_harmonic.py", line 23, in <module>
        grads = grad(loss)(y0)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 482, in fn_bwd_wrapped
        out = fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 394, in _loop_backsolve_bwd
        state, _ = _scan_fun(state, val0, first=True)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 332, in _scan_fun
        _sol = diffeqsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 82, in __call__
        return __self._fun_wrapper(False, args, kwargs)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 78, in _fun_wrapper
        dynamic_out, static_out = self._cached(dynamic, static)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 30, in fun_wrapped
        out = fun(*args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 858, in diffeqsolve
        final_state, aux_stats = adjoint.loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 499, in loop
        final_state, aux_stats = _loop_backsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 509, in __call__
        out = self.fn_wrapped(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 443, in fn_wrapped
        out = self.fn(vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 250, in _loop_backsolve
        return self._loop_fn(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 497, in loop
        final_state = bounded_while_loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 125, in bounded_while_loop
        return lax.while_loop(cond_fun, _body_fun, init_val)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 118, in _body_fun
        _new_val = body_fun(_val, inplace)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 137, in body_fun
        (y, y_error, dense_info, solver_state, solver_result) = solver.step(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/solver/semi_implicit_euler.py", line 42, in step
        y0_1, y0_2 = y0
    ValueError: too many values to unpack (expected 2)
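The last frame shows the `step` method in semi_implicit_euler.py unpacking the state with `y0_1, y0_2 = y0`. That solver therefore expects y0 to be a pair, for example (position, momentum). The traceback suggests that during the backsolve adjoint's backward pass (`_loop_backsolve_bwd`), the solver is instead handed a state with more than two components.

The minimal pure-Python sketch below is not Diffrax's code. `semi_implicit_euler_step` is a hypothetical helper that mirrors the same unpacking and reproduces the error for a harmonic oscillator:

```python
def semi_implicit_euler_step(f, g, t0, dt, y0, args=None):
    """One semi-implicit (symplectic) Euler step on a partitioned state.

    Hypothetical helper for illustration only. y0 must be a pair
    (y0_1, y0_2): f gives the rate of change of the first part from y0_2,
    and g gives the rate of change of the second part from the *updated*
    first part y1_1.
    """
    y0_1, y0_2 = y0  # ValueError if y0 has more than two components
    y1_1 = y0_1 + dt * f(t0, y0_2, args)
    y1_2 = y0_2 + dt * g(t0, y1_1, args)
    return y1_1, y1_2


# Harmonic oscillator: dq/dt = p, dp/dt = -q.
def dq_dt(t, p, args):
    return p


def dp_dt(t, q, args):
    return -q


q1, p1 = semi_implicit_euler_step(dq_dt, dp_dt, 0.0, 0.1, (1.0, 0.0))
print(q1, p1)  # 1.0 -0.1

try:
    # A state with a third component cannot be unpacked into two names.
    semi_implicit_euler_step(dq_dt, dp_dt, 0.0, 0.1, (1.0, 0.0, 0.0))
except ValueError as err:
    print("ValueError:", err)
```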

    bug feature 
    opened by Chenghao-Wu 1
Releases (v0.2.2)
  • v0.2.2 (Nov 15, 2022)

    Performance improvements

    • Diffrax now makes fewer vector field traces in several cases (#172, #174).

    Fixes

    • Many documentation improvements.
    • Fixed several warnings about jax.{tree_map,tree_leaves,...} being moved to jax.tree_util.{tree_map,tree_leaves,...}. (Thanks @jacobusmmsmit!)
    • Fixed the step size controller choking if the error is ever NaN. (#143, #152)
    • Fixed some crashes due to JAX-internal changes. (If you've ever seen an error about not knowing how to rewrite closed_call_p, it's this one.)
    • Fixed an obscure edge-case NaN on the backward pass. This occurred if you were using an implicit solver with an adaptive step size controller, got a rejected step because the implicit solve failed to converge, and also happened to be backpropagating with respect to the controller_state.
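The NaN fix above concerns adaptive step size control. Here is a minimal integral step size controller in plain Python, as a hedged illustration only (this is not Diffrax's PIDController code). `propose_step` and its default constants are hypothetical. Treating a NaN error estimate as a rejected step is an assumption about sensible behaviour, not a description of the actual fix:

```python
import math


def propose_step(error_norm, dt, *, error_order=5, safety=0.9,
                 factormin=0.2, factormax=10.0):
    """Hypothetical integral (I) step size controller.

    error_norm is the scaled local error estimate, where <= 1 means "within
    tolerance". error_order is the exponent denominator, e.g. 5 for a 5(4)
    embedded pair such as Dormand--Prince. Returns (accept, next_dt).
    """
    if math.isnan(error_norm):
        # Treat a NaN error as a failed step: reject it and shrink dt as much
        # as allowed. Without this check the NaN would flow into `factor`,
        # where min/max clipping depends on argument order
        # (max(0.2, nan) is 0.2, but max(nan, 0.2) is nan).
        return False, dt * factormin
    accept = error_norm <= 1.0
    if error_norm == 0.0:
        factor = factormax  # no measurable error: grow as much as allowed
    else:
        factor = safety * error_norm ** (-1.0 / error_order)
    factor = min(factormax, max(factormin, factor))  # limit how fast dt changes
    return accept, dt * factor


print(propose_step(0.5, 0.1))           # accepted; dt grows slightly
print(propose_step(2.0, 0.1))           # rejected; dt shrinks
print(propose_step(float("nan"), 0.1))  # rejected; dt shrinks by factormin
```

Clipping the factor to [factormin, factormax] stops a single unusually small or large error estimate from changing the step size too abruptly.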

    Other

    • Added a new Kalman filter example (#159) (Thanks @SimiPixel!)
    • Brownian motion classes accept pytrees for shape and dtype arguments (#183) (Thanks @ciupakabra!)
    • The main change is an internal refactor: a lot of functionality has moved from diffrax.misc to equinox.internal.
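The pytree change above means a Brownian motion can produce a whole pytree of samples rather than a single array. The hypothetical helper `brownian_increment` below illustrates the idea only; it is not how Diffrax's Brownian motion classes are implemented. It samples one increment W(t + dt) - W(t) ~ Normal(0, dt) per leaf using jax.tree_util. Describing each leaf's shape and dtype together as a jax.ShapeDtypeStruct is an assumption made for this sketch:

```python
import jax
import jax.numpy as jnp


def brownian_increment(key, dt, shape_tree):
    """Sample W(t + dt) - W(t) ~ Normal(0, dt) for every leaf of a pytree.

    Hypothetical helper: shape_tree is any pytree whose leaves are
    jax.ShapeDtypeStruct objects; the result has the same structure.
    """
    def is_struct(x):
        return isinstance(x, jax.ShapeDtypeStruct)

    leaves, treedef = jax.tree_util.tree_flatten(shape_tree, is_leaf=is_struct)
    keys = jax.random.split(key, len(leaves))  # independent noise per leaf
    samples = [
        jnp.sqrt(jnp.asarray(dt, dtype=s.dtype))
        * jax.random.normal(k, s.shape, s.dtype)
        for k, s in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)


shapes = {
    "position": jax.ShapeDtypeStruct((3,), jnp.float32),
    "spin": jax.ShapeDtypeStruct((2, 2), jnp.float32),
}
dW = brownian_increment(jax.random.PRNGKey(0), 0.01, shapes)
print(jax.tree_util.tree_map(lambda x: x.shape, dW))
```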

    New Contributors

    • @jacobusmmsmit made their first contribution in https://github.com/patrick-kidger/diffrax/pull/149
    • @SimiPixel made their first contribution in https://github.com/patrick-kidger/diffrax/pull/159
    • @ciupakabra made their first contribution in https://github.com/patrick-kidger/diffrax/pull/183

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.1...v0.2.2

  • v0.2.1 (Aug 3, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Made is_okay, is_successful, is_event public by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/134
    • Fix implicit adjoints assuming array-valued state by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/136
    • Replace jax tree manipulation methods that are being deprecated with jax.tree_util equivalents by @mahdi-shafiei in https://github.com/patrick-kidger/diffrax/pull/138
    • bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/141

    New Contributors

    • @mahdi-shafiei made their first contribution in https://github.com/patrick-kidger/diffrax/pull/138

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.0...v0.2.1

  • v0.2.0 (Jul 20, 2022)

    • Feature: event handling. In particular, it is now possible to interrupt a diffeqsolve early. See the events page in the docs and the new steady state example.
    • Compilation time improvements:
      • The compilation speed of NewtonNonlinearSolver (and thus, in practice, of all implicit solvers such as Kvaerno3) has been improved by a factor of ~1.5.
      • Compilation time for all Runge--Kutta solvers can be dramatically reduced (by a factor of ~3) by passing, e.g., Dopri5(scan_stages=True). This may increase runtime slightly. At the moment the default is scan_stages=False for all solvers, but this default might change in the future.
    • Various documentation improvements.
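Event handling lets a solve stop before t1, for example once a steady state is reached. Diffrax provides this through its own event API (see the events page in the docs). The sketch below does not use that API. Instead it hand-rolls the same mechanism with jax.lax.while_loop and fixed-step Euler, purely to show how the idea works. `euler_until_steady`, `tol` and `max_steps` are hypothetical names. The `max_steps` cap is the usual safeguard against a loop that never terminates:

```python
import jax
import jax.numpy as jnp


def euler_until_steady(f, y0, dt, t_max, tol, max_steps=10_000):
    """Fixed-step Euler that stops early once ||f(t, y)|| <= tol.

    Hypothetical helper. The loop also stops at t_max or after max_steps,
    so it cannot run forever. Returns (t, y, num_steps).
    """
    def keep_going(state):
        t, y, n = state
        not_steady = jnp.linalg.norm(f(t, y)) > tol
        return (t < t_max) & (n < max_steps) & not_steady

    def step(state):
        t, y, n = state
        return t + dt, y + dt * f(t, y), n + 1

    init = (jnp.asarray(0.0, dtype=y0.dtype), y0, jnp.asarray(0))
    return jax.lax.while_loop(keep_going, step, init)


def decay(t, y):
    return -y


t_end, y_end, steps = euler_until_steady(
    decay, jnp.array([1.0]), dt=0.1, t_max=100.0, tol=1e-3
)
print(int(steps), float(t_end))  # stops after 66 steps, near t = 6.6
```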

    New Contributors

    • @jatentaki made their first contribution in https://github.com/patrick-kidger/diffrax/pull/121

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.2...v0.2.0
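The scan_stages option trades a little runtime for compile time. As we understand the motivation, an unrolled Python loop over the Runge--Kutta stages traces the vector field once per stage, whereas jax.lax.scan traces it once and loops at runtime. This is a hedged sketch of that idea, not Diffrax's implementation. It uses the classical RK4 tableau for brevity rather than Dopri5. `rk_step_unrolled` and `rk_step_scanned` are hypothetical names:

```python
import jax
import jax.numpy as jnp

# Classical RK4 Butcher tableau (illustrative choice, not Dopri5).
A = ((0.0, 0.0, 0.0, 0.0),
     (0.5, 0.0, 0.0, 0.0),
     (0.0, 0.5, 0.0, 0.0),
     (0.0, 0.0, 1.0, 0.0))
B = (1 / 6, 1 / 3, 1 / 3, 1 / 6)
C = (0.0, 0.5, 0.5, 1.0)


def rk_step_unrolled(f, t, y, dt):
    # Python loop over stages: f is called/traced once per stage.
    ks = []
    for i in range(len(B)):
        yi = y + dt * sum(A[i][j] * ks[j] for j in range(i))
        ks.append(f(t + C[i] * dt, yi))
    return y + dt * sum(b * k for b, k in zip(B, ks))


def rk_step_scanned(f, t, y, dt):
    # lax.scan over stages: f is traced inside the scan body, not per stage.
    a = jnp.asarray(A, dtype=y.dtype)
    c = jnp.asarray(C, dtype=y.dtype)
    ks0 = jnp.zeros((len(B),) + y.shape, y.dtype)

    def stage(ks, i):
        # A is strictly lower triangular and unfilled rows of ks are zero,
        # so only earlier stages contribute to yi.
        yi = y + dt * jnp.tensordot(a[i], ks, axes=1)
        return ks.at[i].set(f(t + c[i] * dt, yi)), None

    ks, _ = jax.lax.scan(stage, ks0, jnp.arange(len(B)))
    return y + dt * jnp.tensordot(jnp.asarray(B, dtype=y.dtype), ks, axes=1)


calls = {"n": 0}


def decay(t, y):
    calls["n"] += 1  # Python side effect: counts eager calls or traces of f
    return -y


y0 = jnp.array([1.0])
y_unrolled = rk_step_unrolled(decay, 0.0, y0, 0.1)
unrolled_calls, calls["n"] = calls["n"], 0
y_scanned = rk_step_scanned(decay, 0.0, y0, 0.1)
scanned_calls = calls["n"]
print(unrolled_calls, scanned_calls, y_unrolled, y_scanned)
```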

  • v0.1.2 (May 18, 2022)

    The main change here is a minor technical one: Diffrax no longer initialises the JAX backend as a side effect of being imported.


    Autogenerated release notes as follows:

    What's Changed

    • Removed explicit jaxlib dependency by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/93
    • switch error_if to python if (regarding google/jax/issues/10047) by @amir-saadat in https://github.com/patrick-kidger/diffrax/pull/99
    • Doc fixes by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/100
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/107

    New Contributors

    • @amir-saadat made their first contribution in https://github.com/patrick-kidger/diffrax/pull/99

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.1...v0.1.2

  • v0.1.1 (Apr 7, 2022)

    Diffrax uses some JAX-internal functionality that will shortly be deprecated in JAX. This release adds the appropriate support for both older and newer versions of JAX.


    Autogenerated release notes as follows:

    What's Changed

    • [JAX] Add MHLO lowerings in preparation for xla.lower_fun() removal by @hawkinsp in https://github.com/patrick-kidger/diffrax/pull/91
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/92

    New Contributors

    • @hawkinsp made their first contribution in https://github.com/patrick-kidger/diffrax/pull/91

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.0...v0.1.1

  • v0.1.0 (Mar 30, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Adjusted PIDController by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/89

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.6...v0.1.0

  • v0.0.6 (Mar 29, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Symbolic regression text by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/79
    • Fixed edge case infinite loop on stiff-ish problems (+very bad luck) by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/86

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.5...v0.0.6

  • v0.0.5 (Mar 21, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Doc tweaks by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/72
    • Added JIT wrapper to stiff ODE example by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/75
    • Added autoreleases by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/78
    • Removed overheads from runtime checking when they can be compiled away. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/77

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.4...v0.0.5

  • v0.0.4 (Mar 6, 2022)

    First release using GitHub releases! We'll be using these to serve as a changelog.

    As for what has changed since the v0.0.3 release, we'll let the autogenerated release notes do the talking:

    What's Changed

    • Rewrote RK implementation quite substantially to allow FSAL RK SDE integrators. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/70
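FSAL ("first same as last") means a method's final stage is evaluated at the new solution, so it can be reused as the first stage of the next step. The minimal pure-Python sketch below shows the saving for an ODE. It is not an SDE integrator and not Diffrax's implementation. It uses the Bogacki--Shampine 3(2) tableau, and `bs3_solve` is a hypothetical name. With FSAL, n steps cost 3n + 1 vector field evaluations instead of 4n:

```python
import math


def bs3_solve(f, t0, y0, dt, num_steps):
    """Fixed-step Bogacki--Shampine 3(2) with FSAL reuse (illustration only).

    Returns (y, error_estimates, num_evals).
    """
    num_evals = 0

    def call(t, y):
        nonlocal num_evals
        num_evals += 1
        return f(t, y)

    t, y = t0, y0
    k1 = call(t, y)  # evaluated once; later steps reuse the previous k4
    errors = []
    for _ in range(num_steps):
        k2 = call(t + dt / 2, y + dt * k1 / 2)
        k3 = call(t + 3 * dt / 4, y + 3 * dt * k2 / 4)
        y_new = y + dt * (2 * k1 / 9 + k2 / 3 + 4 * k3 / 9)  # 3rd-order solution
        k4 = call(t + dt, y_new)  # FSAL: last stage is f at the new point
        y_low = y + dt * (7 * k1 / 24 + k2 / 4 + k3 / 3 + k4 / 8)  # embedded 2nd order
        errors.append(abs(y_new - y_low))
        t, y, k1 = t + dt, y_new, k4  # reuse k4 as the next step's k1
    return y, errors, num_evals


y, errors, num_evals = bs3_solve(lambda t, y: -y, 0.0, 1.0, 0.1, 10)
print(num_evals)          # 31 = 3 * 10 + 1, versus 40 without FSAL
print(y, math.exp(-1.0))  # close agreement with the exact solution
```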

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.3...v0.0.4

Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)