Hardware accelerated, batchable and differentiable optimizers in JAX.

Overview

JAXopt

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub, with the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

$ python setup.py install

References

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Levenberg-Marquardt running exceptionally slow unless verbose is enabled

    The Levenberg-Marquardt optimizer runs exceptionally slowly (~30 seconds for 30 iterations) until I turn on verbose=True (~1 second for 30 iterations). Any idea what may be going on? Enabling JIT seems to have no impact. I was hoping to use this for a real-time system, but even at 1 second things are way too slow.

    opened by pablovela5620 23
  • implementation of Fletcher-Reeves Algorithm

    • Polak-Ribière method: to my knowledge, conjugate gradient variants have been quite successful on general unconstrained optimization problems.

    This PR depends on the line search from PR #128.

    • The division in beta is required to guarantee the strong Wolfe conditions, but it raises an error (I don't know why).
    pull ready 
    opened by ita9naiwa 17
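For reference, these are the standard update coefficients of the two nonlinear conjugate gradient variants discussed in this PR, with g_k the gradient at iterate x_k and d_k the search direction. Both divide by the squared norm of the previous gradient, so that denominator can become numerically tiny as the iterates converge; a common safeguard for Polak-Ribière is to clip beta at zero ("PR+").

```latex
\beta_{k+1}^{\mathrm{FR}} = \frac{\|g_{k+1}\|^2}{\|g_k\|^2}, \qquad
\beta_{k+1}^{\mathrm{PR}} = \frac{g_{k+1}^\top (g_{k+1} - g_k)}{\|g_k\|^2}, \qquad
\beta_{k+1}^{\mathrm{PR+}} = \max\left(\beta_{k+1}^{\mathrm{PR}}, 0\right),
\qquad
d_{k+1} = -g_{k+1} + \beta_{k+1} d_k .
```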
  • vmap support in QPs

    Hi, I am experiencing some problems with projection_polyhedron:

    import numpy as np
    import matplotlib.pyplot as plt
    
    import jax
    import jax.numpy as jnp
    
    import jaxopt
    from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron
    
    def myproj3(x):
        A = jnp.array([[1.0, 1.0]])
        b = jnp.array([1.0])
        G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
        h = jnp.array([0.0, 0.0])    
        x = projection_polyhedron(x,hyperparams = (A, b, G, h))
        return x
    
    rng_key = jax.random.PRNGKey(42)
    x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
    p1_x = jax.vmap(myproj3, in_axes=0)(x)  # myproj3 takes a single argument
    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(x[:,0],x[:,1],s=0.5)
    ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()
    

    First, I had to install cvxpy (!pip install cvxpy). Then I got this error:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
      with val = DeviceArray([[-2.37103211,  2.33759997],
                              [ 2.76953806, -2.37750394],
                              [-0.87246632,  0.73224625],
                              ...,
                              [ 2.29799773,  2.81894884],
                              [ 2.4022714 ,  0.80693103],
                              [-0.41563116,  2.83898531]], dtype=float64)
           batch_dim = 0
    

    Does anyone have a hint? Thanks

    enhancement 
    opened by jecampagne 12
  • KKT conditions when the primal solution is a pytree

    Hi, Congrats on the great tool! Inspired by the QuadraticProgramming example I built a code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree. Giving me the following issue:

    TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

    where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

    I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).

    To make it easier to reproduce I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified the obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is: ([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

    Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

    Notice also that the test on the line just above (checking that the primal solution is correct) still holds provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).

    Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

    Thanks!

    opened by FerranAlet 11
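The workaround described in this issue (storing the primal variable as one flat array and reshaping it when needed) can be written with jax.flatten_util.ravel_pytree, which returns both the flat vector and a function that rebuilds the original pytree. A sketch with a made-up two-leaf primal variable and a hypothetical obj_fun; it does not change how JAXopt itself handles pytrees in KKT conditions.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree-valued primal variable: a list of two arrays.
primal = [jnp.array([1.0, 2.0]), jnp.array([[3.0], [4.0]])]

# Flatten to a single 1-D array plus a function that restores the structure.
flat_primal, unravel = ravel_pytree(primal)

def obj_fun(flat_x):
    # Rebuild the pytree only inside the function that needs its structure.
    x = unravel(flat_x)
    return jnp.sum(x[0] ** 2) + jnp.sum(x[1] ** 2)

restored = unravel(flat_primal)
```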
  • Hot fix: corrected condition in lbfgs

    The feature I had introduced in https://github.com/google/jaxopt/pull/323 was failing when the run function was jitted and was a no-op when not because of the following reason:

     ~True == -2  # this is True
    

    Therefore when jitted it was complaining about different types in a condition function, and when not jitted it was equivalent to always being False.

    EDIT

    Actually I am still running into an error when jitted, so will continue to investigate.

    The gist of the error is Abstract tracer value encountered where concrete value is expected, basically doing (not self.stop_if_linesearch_fails | ~state.failed_linesearch) is not allowed because one is a bool and the other is an abstract value.

    pull ready 
    opened by zaccharieramzi 7
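Two Python details explain the behaviour described in this PR. Bitwise ~ on a bool operates on its integer value, so ~True is -2 rather than False (recent Python versions also emit a DeprecationWarning for it). And not binds more loosely than |, so "not a | b" is parsed as "not (a | b)"; under jit that not is applied to a traced array, which requires a concrete value. The snippet below checks both facts in plain Python; with JAX arrays, jnp.logical_not and jnp.logical_or avoid both pitfalls.

```python
# ~ is bitwise NOT on integers; True behaves like the int 1, so ~True == -2.
bitwise = ~True
# ~False == -1; both results are truthy, so ~ never acts as a logical negation here.
bitwise_false = ~False

# `not` has lower precedence than `|`: this is not (False | True), i.e. False.
parsed = not False | True
# The intended expression needs explicit parentheses.
intended = (not False) | True
```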
  • Issue with gradients wrt optimality fn parameters through root finding vjp

    First of all, thanks a lot for this library! Really useful tools! I'm interested in getting at least 2nd order gradients through root finding, and I'm finding an odd behavior that I wanted to report.

    Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

    def inv_f(x, aux):
      bisec = Bisection(optimality_fun=F, lower=0.0, upper=1., 
                        check_bracket=False, unroll=True)
      return bisec.run(aux=aux).params
    
    # Here I extract the value part of the vjp, but the grad part also gives wrong results
    test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 
    
    jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients
    

    Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

    I made a small demo notebook that reproduces this issue here.

    As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.

    Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work... Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).

    bug 
    opened by EiffL 7
  • misc improvements to robust training example

    main changes:

    • Fixes #134 by normalizing in-place.
    • Plot convergence curves for both clean and adversarial accuracy.
    • Replace the fast gradient sign method (FGSM) with the much more powerful PGD method.
    • Be able to select different datasets.
    • Homogenize the API with respect to the other examples. For example, this now uses the same load_dataset, CNN, loss_fun and accuracy as flax_image_classif.py. Most of the command-line flags have also been homogenized.
    pull ready 
    opened by fabianp 7
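Projected gradient descent (PGD) replaces the single fast gradient sign step with several smaller signed-gradient steps, each followed by a projection back onto the L-infinity ball of radius eps around the clean input. A minimal JAX sketch; the toy loss, eps, step size and number of steps are illustrative assumptions, not the values used in robust_training.py.

```python
import jax
import jax.numpy as jnp

def pgd_attack(loss_fn, x, eps=0.1, step_size=0.02, num_steps=10):
    """Approximately maximize loss_fn within an L-infinity ball of radius eps around x."""
    grad_fn = jax.grad(loss_fn)
    x_adv = x
    for _ in range(num_steps):
        # Ascent step along the sign of the gradient (one FGSM-style step).
        x_adv = x_adv + step_size * jnp.sign(grad_fn(x_adv))
        # Project back onto the eps-ball around the clean input.
        x_adv = jnp.clip(x_adv, x - eps, x + eps)
    return x_adv

# Toy scalar loss standing in for a classifier's cross-entropy.
loss_fn = lambda z: jnp.sum(jnp.sin(z))
x = jnp.linspace(-1.0, 1.0, 5)
x_adv = pgd_attack(loss_fn, x)
```

For images one would typically also clip x_adv to the valid pixel range after each step.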
  • Bisection hanging

    I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

    The basic structure includes 2 functions that are both jitted (so I assume it should be able to compile ok):

    @jit
    def f1(parameters):
        ....
        return jax.numpy.array([a,b,c])
    
    @jit
    def opt_fun(x):
        f1(x,params)
        .... 
        return float_value
    

    When I call scipy.optimize.bisect(opt_fun, x0, x1) it runs with no issue, but jaxopt.Bisection(opt_fun, x0, x1).run(None) hangs with ~10% CPU usage and 55% memory usage on a 2018 i9 MacBook Pro with 32GB of memory.

    I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.

    opened by jjruby09 7
  • Incompatible shape in solve_normal_cg

    When A.shape = (N, P) for N != P, I run into shape errors when trying to use solve_normal_cg for fitting the normal equations.

    I have a small reproducible example below for N > P, but the error also occurs when P > N.

    import jax.numpy as jnp
    import numpy as np
    N = 1000
    P = 3
    prob = np.random.uniform(0.01, 0.5, size=P)
    h2g = 0.1
    X = np.random.binomial(2, p=prob, size=(N, P))
    b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
    y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
    
    import jaxopt as jopt
    jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [11], in <module>
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
        148 if ridge is not None:
        149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
    --> 151 Ab = _rmatvec(matvec, b)
        153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
        112 def _rmatvec(matvec, x):
        113   """Computes A^T x, from matvec(x) = A x, where A is square."""
    --> 114   transpose = jax.linear_transpose(matvec, x)
        115   return transpose(x)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
       2208 in_dtypes = map(dtypes.dtype, in_avals)
       2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
    -> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
       2212                                              instantiate=True)
       2213 out_avals, _ = unzip2(out_pvals)
       2214 out_dtypes = map(dtypes.dtype, out_avals)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
        503 with core.new_main(JaxprTrace) as main:
        504   fun = trace_to_subjaxpr(fun, main, instantiate)
    --> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
        506   assert not env
        507   del main, fun, env
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
        163 gen = gen_static_args = out_store = None
        165 try:
    --> 166   ans = self.f(*args, **dict(self.params, **kwargs))
        167 except:
        168   # Some transformations yield from inside context managers, so we have to
        169   # interrupt them before reraising the exception. Otherwise they will only
        170   # get garbage-collected at some later time, running their cleanup tasks only
        171   # after this exception is handled, which can corrupt the global state.
        172   while stack:
    
    Input In [11], in <lambda>(x)
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
       4194   return lax.mul(a, b)
       4195 if _max(a_ndim, b_ndim) <= 2:
    -> 4196   return lax.dot(a, b, precision=precision)
       4198 if b_ndim == 1:
       4199   contract_dims = ((a_ndim - 1,), (0,))
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
        664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
        665                      precision=precision, preferred_element_type=preferred_element_type)
        666 else:
    --> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        668       lhs.shape, rhs.shape))
    
    TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
    
    opened by quattro 6
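The traceback shows the transpose being built with b (shape (N,)) as the example input, whereas the operator x -> X @ x expects x of shape (P,). The normal equations X^T X x = X^T y need the transpose taken at a point in the domain of the operator. A sketch of the shape-consistent computation with jax.linear_transpose and jax.scipy.sparse.linalg.cg; the data are illustrative, and per the v0.4 release notes solve_normal_cg itself was later fixed for nonsquare operators.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

N, P = 1000, 3
X = jax.random.normal(jax.random.PRNGKey(0), (N, P))
y = X @ jnp.array([0.5, -1.0, 2.0])

matvec = lambda x: X @ x  # maps R^P -> R^N

# Transpose at a point of the *domain* (shape (P,)), not at y (shape (N,)).
rmatvec = jax.linear_transpose(matvec, jnp.zeros(P))
Xty = rmatvec(y)[0]  # X^T y, shape (P,)

# Normal-equation operator x -> X^T X x, which is square (P -> P).
normal_matvec = lambda x: rmatvec(matvec(x))[0]
x_hat, _ = cg(normal_matvec, Xty)
```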
  • Initial stepsize not exposed in LBFGS constructor [question/bug?]

    I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

    I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

    Thanks in advance, and thanks for a really cool library.

    opened by erdmann 6
  • Infinities and NaNs in quadratic_prog when c=0

    Hi,

    I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing l2 norm squared of the primal subject to some equality constraints (I don't have inequalities).

    However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

    The modification just involves setting c=0, so:

    def test_qp_eq_only_c_zero(self):
      Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
      c = jnp.array([0.0, 0.0]) #ONLY CHANGE
      A = jnp.array([[1.0, 1.0]])
      b = jnp.array([1.0])
      qp = QuadraticProgramming(tol=1e-7)
      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
      sol = qp.run(**hyperparams).params
      self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
      self._check_derivative_A_and_b(qp, hyperparams, A, b)
    

    Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

    Thanks!

    opened by FerranAlet 6
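With only equality constraints, the QP min_x 0.5 x^T Q x + c^T x subject to A x = b is solved exactly by one linear system in the primal x and the equality multipliers nu, and c = 0 poses no difficulty for that system as long as A has full row rank and Q is positive definite on the null space of A. A NumPy sketch on the test's data; this is not how QuadraticProgramming works internally (the release notes list jaxopt.EqualityConstrainedQP among its replacements).

```python
import numpy as np

Q = 2 * np.array([[2.0, 0.5], [0.5, 1.0]])
c = np.zeros(2)
A = np.array([[1.0, 1.0]])
b = np.array([1.0])

# KKT system: [[Q, A^T], [A, 0]] @ [x; nu] = [-c; b]
n, m = Q.shape[0], A.shape[0]
kkt = np.block([[Q, A.T], [A, np.zeros((m, m))]])
rhs = np.concatenate([-c, b])
sol = np.linalg.solve(kkt, rhs)
x, nu = sol[:n], sol[n:]
```

For this data the solution is x = (0.25, 0.75) with nu = -1.75.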
  • OptaxSolver Error: too many positional arguments

    Hello! I tried to implement the example of implicit differentiation shown here, but with my own functions. The task is to find the mean of a set of vectors named X via gradient descent.

    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    import jax
    import jax.numpy as jnp
    from jax import grad, random, jit
    from jax import jacobian, hessian, jacfwd, jacrev
    key = random.PRNGKey(0)
    
    import jaxopt
    from jaxopt import implicit_diff
    from jaxopt import linear_solve
    from jaxopt import OptaxSolver, GradientDescent
    import optax
    
    def euclidean_distance(a, b):
        """
        Squared Euclidean distance
        """
        return jnp.inner(a - b, a - b)
    
    def weighted_distance(x, X, w):
        loss = 0
        for i, obj in enumerate(X):
            loss += w[i] * euclidean_distance(obj, x)
        return loss
    
    def identical(Y, Y_grad):
        return Y
    

    Algorithm for finding mean:

    # Mean calculation for manifolds with gradient descent
    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):
        
        if weights is None:
            weights = jnp.full((X_set.shape[0]), 1) / X_set.shape[0]
    
        # init mean with random element from set
        Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 
        
        if plot_loss_flag:
            plot_loss = []
            prev_loss = 0
            plato_iter = 0
            plato_reached = False
        
        for i in range(n_iter):
            
            # calculate loss
            loss = weighted_distance(Y, X_set, weights)
    
            if plot_loss_flag:
                if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                    if not plato_reached:
                        plato_iter = i
                        plato_reached = True
                else:
                    prev_loss = loss
                    plato_reached = False
        
            Y_grad = grad(weighted_distance, argnums= 0)(Y, X_set, weights)
            
            # calculate Riemannian gradient
            riem_grad_Y = Y_grad
            
            # update Y
            Y_step = Y - lr * riem_grad_Y
            
            # project new Y on manifold with retraction
            Y = Y_step
            
            if plot_loss_flag:
              # collect loss for plotting
              plot_loss.append(loss)
        
        if plot_loss_flag:
            print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
            fig, ax = plt.subplots()
            ax.plot(plot_loss)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            plt.show()
        return Y
    

    You can launch it like this:

    d = 2
    m = 4
    X = jax.random.uniform(key, (m,d))
    euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
    

    As you can see, I am calculating a weighted mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm so that the resulting mean is as close to X[0] as possible:

    def global_task_objective(w, X, target_point, lr, n_iter):
        x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
        loss = euclidean_distance(x, target_point)
        return loss, x
    
    target_point = X[0]
    
    w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 
    
    lr = 1e-3
    n_iter = 100
    
    global_task_objective(w_init, X, target_point, lr, n_iter)
    solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
    state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    

    The problem emerges when I call

    w_init, state = solver.update(params=w_init, 
                                 state=state, 
                                 X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    
    Meanwhile, the official example with ridge regression works perfectly. Any suggestions?
    opened by MarioAuditore 0
  • Custom loop pjit example

    A MWE of how jax.experimental.pjit can be used in JAXopt (see also PR #346).

    NOTE: jax.experimental.pjit is not yet supported in Colab. However, this example illustrates how users with access to Google Cloud TPUs may use jax.experimental.pjit in combination with JAXopt solvers.

    pull ready 
    opened by fllinares 2
  • Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    This fixes #351.

    @mblondel I couldn't use your suggestion of creating a new type of init LBFGSInit because the init_params variable is used for both init_state and update. Therefore I would have had to add case distinctions in the 2 functions which seemed unreasonable. Rather I took the approach I saw in some other iterative solvers which was to add an extra keyword argument to init_state, update and _value_and_grad_fun.

    I added a test to make sure that this runs, but I am not sure whether we need to add a test to make sure that it improves some cases. I also don't know whether we should test that differentiation is ok.

    opened by zaccharieramzi 5
  • Enable warm-starting the hessian approximation in L-BFGS

    Currently one can only provide an initial estimate of the solution, which enables warm-starting the iterates. But for quasi-Newton methods, it can also be a good idea to provide an initial estimate of the Hessian approximation, typically when solving a similar problem multiple times.

    This was for example done in HOAG by @fabianp (see https://github.com/fabianp/hoag/blob/master/hoag/hoag.py#L109).

    I am willing to implement this in the next few weeks.

    As I know it is of interest to them as well, cc-ing @marius311 and @mblondel

    opened by zaccharieramzi 2
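In L-BFGS the inverse-Hessian approximation is never stored as a matrix: it is implied by the most recent pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i and applied through the two-loop recursion. Warm-starting it therefore amounts to seeding that memory with pairs from a previous, similar solve. The NumPy sketch below is a textbook two-loop recursion, not JAXopt's implementation or API; the pairs come from a synthetic quadratic and are assumptions for illustration.

```python
import numpy as np

def two_loop(grad, s_list, y_list):
    """Return H @ grad, where H is the L-BFGS inverse-Hessian approximation
    built from memory pairs (oldest first) and H0 = gamma * I."""
    q = grad.copy()
    rhos = [1.0 / (y @ s) for s, y in zip(s_list, y_list)]
    alphas = []
    # First loop: newest pair to oldest pair.
    for s, y, rho in reversed(list(zip(s_list, y_list, rhos))):
        alpha = rho * (s @ q)
        q -= alpha * y
        alphas.append(alpha)
    # Initial scaling H0 = gamma * I from the newest pair.
    gamma = (s_list[-1] @ y_list[-1]) / (y_list[-1] @ y_list[-1])
    r = gamma * q
    # Second loop: oldest pair to newest pair.
    for (s, y, rho), alpha in zip(zip(s_list, y_list, rhos), reversed(alphas)):
        beta = rho * (y @ r)
        r += s * (alpha - beta)
    return r

# Synthetic memory from a quadratic with Hessian A (so y = A s); in a warm
# start these pairs would come from a previous solve of a similar problem.
rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
A = M @ M.T + 4 * np.eye(4)
s_list = [rng.normal(size=4) for _ in range(3)]
y_list = [A @ s for s in s_list]
```

A quick sanity check is the secant condition: applying the approximation to the newest y returns the newest s.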
  • Batched QP (and other optimization algorithm)

    I'm trying to make OSQP batchable (so I can make it a layer in neural networks, like OptNet), but I couldn't find any documentation yet about using vmap to solve batched versions of optimization problems.

    opened by jn-tang 1
Releases(jaxopt-v0.5.5)
  • jaxopt-v0.5.5(Oct 20, 2022)

    New features

    • Added MAML example by Fabian Pedregosa based on initial code by Paul Vicol and Eric Jiang.
    • Added the possibility to stop LBFGS after a line search failure, by Zaccharie Ramzi.
    • Added gamma to LBFGS state, by Zaccharie Ramzi.
    • Added jaxopt.BFGS, by Mathieu Blondel.
    • Added value_and_grad option to all gradient-based solvers, by Mathieu Blondel.
    • Added Fenchel-Young loss, by Quentin Berthet.
    • Added projection_sparse_simplex, by Tianlin Liu.

    Bug fixes and enhancements

    • Fixed missing args, kwargs in resnet example, by Louis Béthune.
    • Corrected the implicit diff examples, by Zaccharie Ramzi.
    • Small optimization in l2-regularized semi-dual OT, by Mathieu Blondel.
    • Numerical stability improvements in jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Dtype consistency in LBFGS, by Alex Botev.

    Deprecations

    • jaxopt.QuadraticProgramming is now fully removed. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Alex Botev, Amir Saadat, Fabian Pedregosa, Louis Béthune, Mathieu Blondel, Quentin Berthet, Tianlin Liu, Zaccharie Ramzi.

  • jaxopt-v0.5(Aug 30, 2022)

    New features

    • Added optimal transport related projections: projection_transport, projection_birkhoff, kl_projection_transport, and kl_projection_birkhoff, by Mathieu Blondel (semi-dual formulation) and Tianlin Liu (dual formulation).

    Bug fixes and enhancements

    • Fix LaTeX rendering issue in notebooks, by Amélie Héliou.
    • Avoid gradient recompilations in zoom line search, by Mathieu Blondel.
    • Fix unused Jacobian issue in jaxopt.ScipyRootFinding, by Louis Béthune.
    • Use zoom line search by default in jaxopt.LBFGS and jaxopt.NonlinearCG, by Mathieu Blondel.
    • Pass tolerance argument to jaxopt.ScipyMinimize, by pipme.
    • Handle has_aux in jaxopt.LevenbergMarquardt, by Keunhong Park.
    • Add maxiter keyword argument in jaxopt.ScipyMinimize, by Fabian Pedregosa.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Keunhong Park, Fabian Pedregosa, pipme.

  • jaxopt-v0.4.3(Jun 28, 2022)

    New features

    • Added zoom line search in jaxopt.LBFGS, by Mathieu Blondel. It can be enabled with the linesearch="zoom" option.

    Bug fixes and enhancements

    • Added support for quadratic polynomial fun in jaxopt.BoxOSQP and jaxopt.OSQP, by Louis Béthune.
    • Added a notebook for the dataset distillation example, by Amélie Héliou.
    • Fixed wrong links and deprecation warnings in notebooks, by Fabian Pedregosa.
    • Changed losses to avoid roundoff, by Jack Valmadre.
    • Fixed init_params bug in multiclass_svm example, by Louis Béthune.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Fabian Pedregosa, Jack Valmadre.

  • jaxopt-v0.4.2(Jun 10, 2022)

  • jaxopt-v0.4.1(Jun 10, 2022)

    Bug fixes and enhancements

    • Improvements in jaxopt.LBFGS: fixed bug when using use_gamma=True, added stepsize option, strengthened tests, by Mathieu Blondel.
    • Fixed link in resnet notebook, by Fabian Pedregosa.

    Contributors

    Fabian Pedregosa, Mathieu Blondel.

  • jaxopt-v0.4(May 24, 2022)

    New features

    • Added solver jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Added solver jaxopt.BoxCDQP, by Mathieu Blondel.
    • Added projection_hypercube, by Mathieu Blondel.

    Bug fixes and enhancements

    • Fixed solve_normal_cg when the linear operator is “nonsquare” (does not map to a space of same dimension), by Mathieu Blondel.
    • Fixed edge case in jaxopt.Bisection, by Mathieu Blondel.
    • Replaced deprecated tree_multimap with tree_map, by Fan Yang.
    • Added support for leaf cond pytrees in tree_where, by Felipe Llinares.
    • Added Python 3.10 support officially, by Jeppe Klitgaard.
    • In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
    • Converted the “Resnet” and “Adversarial Training” examples to notebooks, by Fabian Pedregosa.

    Contributors

    Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3.1(Feb 28, 2022)

    New features

    • Pjit-based example of data parallel training using Flax, by Felipe Llinares.

    Bug fixes and enhancements

    • Support for GPU and state of the art adversarial training algorithm (PGD) on the robust_training.py example, by Fabian Pedregosa.
    • Update line search in LBFGS to use jit and unroll from LBFGS, by Ian Williamson.
    • Support dynamic maximum iteration count in iterative solvers, by Roy Frostig.
    • Fix tree_where for singleton pytrees, by Louis Béthune.
    • Remove QuadraticProg in projections and set init_params=None by default in QP solvers, by Louis Béthune.
    • Add missing 'value' attribute in LbfgsState, by Mathieu Blondel.

    Contributors

    Felipe Llinares, Fabian Pedregosa, Ian Williamson, Louis Béthune, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3(Jan 31, 2022)

    New features

    • jaxopt.LBFGS
    • jaxopt.BacktrackingLineSearch
    • jaxopt.GaussNewton
    • jaxopt.NonlinearCG

    Bug fixes and enhancements

    • Support implicit AD in higher-order differentiation.

    Contributors

    Amir Saadat, Fabian Pedregosa, Geoffrey Négiar, Hyunsung Lee, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.2(Dec 18, 2021)

    New features

    • Quadratic programming solvers jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP
    • Iterative refinement

    New examples

    • Resnet example with Flax and JAXopt.

    Bug fixes and enhancements

    • Prevent recompilation of loops in solver.run if executing without jit.
    • Prevents recomputation of gradient in OptaxSolver.
    • Make solver.update jittable and ensure output states are consistent.
    • Allow Callable for the stepsize argument in jaxopt.ProximalGradient, jaxopt.ProjectedGradient and jaxopt.GradientDescent.

    Deprecated features

    • jaxopt.QuadraticProgramming is deprecated and will be removed in v0.3. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Geoffrey Negiar, Louis Bethune, Mathieu Blondel, Vikas Sindhwani.

  • jaxopt-v0.1.1(Oct 19, 2021)

    New features

    • Added solver jaxopt.ArmijoSGD
    • Added example Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
    • Added example Comparison of different SGD algorithms.

    Bug fixes

    • Allow non-jittable proximity operators in jaxopt.ProximalGradient
    • Raise an exception if a quadratic program is infeasible or unbounded

    Contributors

    Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

  • jaxopt-v0.1(Oct 14, 2021)

    Classes

    • jaxopt.AndersonAcceleration
    • jaxopt.AndersonWrapper
    • jaxopt.Bisection
    • jaxopt.BlockCoordinateDescent
    • jaxopt.FixedPointIteration
    • jaxopt.GradientDescent
    • jaxopt.MirrorDescent
    • jaxopt.OptaxSolver
    • jaxopt.PolyakSGD
    • jaxopt.ProjectedGradient
    • jaxopt.ProximalGradient
    • jaxopt.QuadraticProgramming
    • jaxopt.ScipyBoundedLeastSquares
    • jaxopt.ScipyBoundedMinimize
    • jaxopt.ScipyLeastSquares
    • jaxopt.ScipyMinimize
    • jaxopt.ScipyRootFinding
    • Implicit differentiation

    Examples

    • Binary kernel SVM with intercept.
    • Image classification example with Flax and JAXopt.
    • Image classification example with Haiku and JAXopt.
    • VAE example with Haiku and JAXopt.
    • Implicit differentiation of lasso.
    • Multiclass linear SVM (without intercept).
    • Non-negative matrix factorization (NMF) using alternating minimization.
    • Dataset distillation.
    • Implicit differentiation of ridge regression.
    • Robust training.
    • Anderson acceleration of gradient descent.
    • Anderson acceleration of block coordinate descent.
    • Anderson acceleration in application to Picard–Lindelöf theorem.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Robert Gower, Louis Bethune, Marco Cuturi, Mathieu Blondel, Peter Hawkins, Quentin Berthet, Roy Frostig, Ta-Chu Kao.
