trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically and efficiently differentiable via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with NumPy.
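Because Trajax's solvers are JAX functions, these transformations compose on them as they do on any other JAX code. The sketch below does not call Trajax. `dynamics`, `cost`, and `total_cost` are hypothetical stand-ins for a toy one-dimensional system, used only to show jax.grad, jax.jit, and jax.vmap applied together to a trajectory cost.

```python
import jax
import jax.numpy as jnp


def dynamics(x, u, t):
    """Hypothetical toy system: x[t + 1] = x[t] + 0.1 * u[t], with n = d = 1."""
    del t  # time-invariant
    return x + 0.1 * u


def cost(x, u, t):
    """Hypothetical quadratic stage cost; returns a scalar."""
    del t
    return jnp.sum(x ** 2) + 0.01 * jnp.sum(u ** 2)


def total_cost(U, x0):
    """Roll out from x0 under controls U (shape [T, d]) and sum the costs."""
    def step(x, inputs):
        u, t = inputs
        return dynamics(x, u, t), cost(x, u, t)

    x_final, stage_costs = jax.lax.scan(step, x0, (U, jnp.arange(U.shape[0])))
    return jnp.sum(stage_costs) + jnp.sum(x_final ** 2)  # quadratic terminal cost on X[T]


# Gradient with respect to U, compiled with jit, then batched over initial states.
grad_fn = jax.jit(jax.grad(total_cost))
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0))  # shared U, one x0 per row

U = jnp.zeros((5, 1))
x0s = jnp.array([[1.0], [-1.0], [0.5]])
grads = batched_grad(U, x0s)
print(grads.shape)  # (3, 5, 1)
```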

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.
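Trajax's own derivative rules are not reproduced here. As a generic illustration of the mechanism only, the toy example below uses jax.custom_vjp to register a hand-written derivative for an iterative solver (a hypothetical `newton_sqrt`). jax.grad then applies the implicit function theorem at the solution instead of backpropagating through every Newton iteration.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def newton_sqrt(a):
    """Hypothetical 'solver': find x with x**2 - a = 0 by Newton's method (a > 0)."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(_, x):
        return x - (x ** 2 - a) / (2.0 * x)

    return jax.lax.fori_loop(0, 50, body, jnp.maximum(a, 1.0))


def newton_sqrt_fwd(a):
    x = newton_sqrt(a)
    return x, x  # keep only the solution as the residual for the backward pass


def newton_sqrt_bwd(x, g):
    # Implicit function theorem on F(x, a) = x**2 - a = 0:
    # dx/da = -(dF/da) / (dF/dx) = 1 / (2x), so the 50 iterations are never differentiated.
    return (g / (2.0 * x),)


newton_sqrt.defvjp(newton_sqrt_fwd, newton_sqrt_bwd)

print(newton_sqrt(jnp.float32(4.0)))            # close to 2.0
print(jax.grad(newton_sqrt)(jnp.float32(4.0)))  # close to 0.25 = 1 / (2 * 2)
```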

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider classical optimal control tasks: optimizing the trajectory of a given discrete-time dynamical system. Given a stage cost function c, a final cost function c_final, a dynamics function f, an initial state x0, and a time horizon T, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))
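In standard notation, with X holding the T + 1 states and U the T controls, the same problem reads:

```latex
\min_{X,\,U} \; \sum_{t=0}^{T-1} c(X_t, U_t, t) + c_{\mathrm{final}}(X_T)
\quad \text{subject to} \quad
X_0 = x_0, \qquad X_{t+1} = f(X_t, U_t, t), \quad t = 0, \dots, T-1.
```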

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).
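As a sketch of these conventions for n = d = 1 (the system and weights below are hypothetical), both functions take length-1 vectors, the dynamics return a length-1 vector, and the cost returns a scalar. A common pitfall is wrapping an already length-1 result as jnp.array([xdot]), which produces shape (1, 1) instead of (1,).

```python
import jax.numpy as jnp


def dynamics(state, action, time_step):
    """Hypothetical 1-D system, Euler-discretized with dt = 0.01; returns F[1]."""
    del time_step
    xdot = state ** 2 + 5.0 * (state - action)  # elementwise, so shape (1,)
    return state + 0.01 * xdot                  # keep shape (1,); no extra jnp.array([...])


def cost(state, action, time_step):
    """Hypothetical quadratic cost; jnp.sum reduces it to a scalar of shape ()."""
    del time_step
    return jnp.sum((state - 1.0) ** 2) + 0.1 * jnp.sum(action ** 2)


x = jnp.array([-0.9])  # F[1]: a length-1 vector, not a bare float
u = jnp.array([0.0])   # F[1]
print(dynamics(x, u, 0).shape)  # (1,)
print(cost(x, u, 0).shape)      # ()
```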

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).
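To make the shared signature concrete, here is a hypothetical solver, `toy_random_shooting`, that follows the common prefix and returns a SolverOutput-shaped tuple. It is a minimal sketch of random shooting written for illustration, not Trajax's random_shooting. The helper names, the extra `key`, `num_samples`, and `stddev` arguments, and the choice to evaluate the final cost as cost(X[T], 0, T) are all assumptions.

```python
import jax
import jax.numpy as jnp


def rollout(dynamics, x0, U):
    """Return the states X, shape [T + 1, n], visited from x0 under U, shape [T, d]."""
    def step(x, inputs):
        u, t = inputs
        x_next = dynamics(x, u, t)
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (U, jnp.arange(U.shape[0])))
    return jnp.concatenate([x0[None], xs], axis=0)


def trajectory_cost(cost, X, U):
    """Sum of cost(X[t], U[t], t) for t < T, plus cost(X[T], 0, T) as the final cost (assumed)."""
    T = U.shape[0]
    stage = jax.vmap(cost)(X[:-1], U, jnp.arange(T))
    return jnp.sum(stage) + cost(X[-1], jnp.zeros_like(U[0]), T)


def toy_random_shooting(cost, dynamics, initial_state, initial_actions, key,
                        num_samples=256, stddev=0.5):
    """Hypothetical solver with the shared signature prefix; not Trajax's random_shooting."""
    noise = stddev * jax.random.normal(key, (num_samples,) + initial_actions.shape)
    candidates = initial_actions[None] + noise  # [num_samples, T, d]
    objs = jax.vmap(
        lambda U: trajectory_cost(cost, rollout(dynamics, initial_state, U), U))(candidates)
    best = jnp.argmin(objs)
    U = candidates[best]
    X = rollout(dynamics, initial_state, U)
    # (F[T + 1, n], F[T, d], float, *solver_outputs); the extra output is the winning sample index.
    return X, U, objs[best], best


# Usage on a hypothetical 1-D system.
def dyn(x, u, t):
    return x + 0.1 * u


def c(x, u, t):
    return jnp.sum(x ** 2) + 0.01 * jnp.sum(u ** 2)


X, U, obj, best = toy_random_shooting(
    c, dyn, jnp.array([1.0]), jnp.zeros((10, 1)), jax.random.PRNGKey(0))
print(X.shape, U.shape)  # (11, 1) (10, 1)
```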

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.
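The backward Riccati recursion behind a time-varying LQR solve can be sketched in a few lines of JAX. This is not Trajax's tvlqr: the function name `tvlqr_sketch`, its argument order, and the restriction to purely quadratic costs (no linear or cross terms) and linear dynamics (no offsets) are assumptions made to keep the sketch short.

```python
import jax
import jax.numpy as jnp


def tvlqr_sketch(Q, R, A, B, Q_final):
    """Hypothetical time-varying LQR solve (not Trajax's tvlqr signature).

    Minimizes sum_t 0.5 * (x_t' Q_t x_t + u_t' R_t u_t) + 0.5 * x_T' Q_final x_T
    subject to x_{t+1} = A_t x_t + B_t u_t.
    Shapes: Q [T, n, n], R [T, d, d], A [T, n, n], B [T, n, d], Q_final [n, n].
    Returns feedback gains K of shape [T, d, n]; the optimal control is u_t = K_t x_t.
    """
    def backward_step(P_next, mats):
        Q_t, R_t, A_t, B_t = mats
        # K_t = -(R_t + B_t' P_{t+1} B_t)^{-1} B_t' P_{t+1} A_t
        K_t = -jnp.linalg.solve(R_t + B_t.T @ P_next @ B_t, B_t.T @ P_next @ A_t)
        # Riccati update: P_t = Q_t + A_t' P_{t+1} (A_t + B_t K_t)
        P_t = Q_t + A_t.T @ P_next @ (A_t + B_t @ K_t)
        return P_t, K_t

    # reverse=True runs t = T-1, ..., 0 but stacks K in the original time order.
    _, K = jax.lax.scan(backward_step, Q_final, (Q, R, A, B), reverse=True)
    return K


# Usage: a double integrator with dt = 0.1 over T = 3 steps.
T = 3
A = jnp.tile(jnp.array([[1.0, 0.1], [0.0, 1.0]]), (T, 1, 1))
B = jnp.tile(jnp.array([[0.0], [0.1]]), (T, 1, 1))
Q = jnp.tile(jnp.eye(2), (T, 1, 1))
R = jnp.tile(jnp.eye(1), (T, 1, 1))
K = tvlqr_sketch(Q, R, A, B, jnp.eye(2))
print(K.shape)  # (3, 1, 2)
```

For a scalar system with A = B = Q = R = Q_final = 1 and T = 2, the recursion gives gains -0.6 and -0.5, which matches minimizing the two-step cost by hand.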

Objectives

One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float

Combined with JAX's autodiff capabilities, this function offers a starting point for writing a custom first-order solver. For example:

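import jax
import trajax

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  # argnums=2 differentiates with respect to U, the third argument in the call below;
  # a tuple such as (2,) would make grad_fn return a tuple instead of an array.
  grad_fn = jax.grad(trajax.objective, argnums=2)
  for _ in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, U, x0)  # plain gradient-descent step
  return U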

The solvers provided by Trajax are actually built around this objective function. For instance, the scipy_minimize solver simply calls scipy.optimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.
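The same pattern can be sketched outside Trajax. In the example below, the toy objective `objective_flat`, the horizon, and the weights are hypothetical. The gradient comes from jax.grad, and the Hessian-vector product is computed by forward-over-reverse differentiation (jax.jvp of jax.grad), a swapped-in alternative to materializing the full Hessian with jax.hessian. Double precision is enabled because SciPy works in float64.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # match SciPy's float64 arithmetic

T, d = 10, 1
x0 = jnp.array([1.0])


def objective_flat(u_flat):
    """Hypothetical toy objective on a flat control vector of length T * d."""
    U = u_flat.reshape(T, d)

    def step(x, u):
        # Toy dynamics x+ = x + 0.1 u; stage cost x^2 + 0.01 u^2.
        return x + 0.1 * u, jnp.sum(x ** 2) + 0.01 * jnp.sum(u ** 2)

    x_final, stage_costs = jax.lax.scan(step, x0, U)
    return jnp.sum(stage_costs) + jnp.sum(x_final ** 2)


value_fn = jax.jit(objective_flat)
grad_fn = jax.jit(jax.grad(objective_flat))


@jax.jit
def hvp_fn(u, v):
    # Hessian-vector product: directional derivative of the gradient along v.
    return jax.jvp(jax.grad(objective_flat), (u,), (v,))[1]


result = minimize(
    fun=lambda u: float(value_fn(jnp.asarray(u))),
    x0=np.zeros(T * d),
    jac=lambda u: np.array(grad_fn(jnp.asarray(u))),
    hessp=lambda u, v: np.array(hvp_fn(jnp.asarray(u), jnp.asarray(v))),
    method="Newton-CG",
)
print(result.fun)
```

Jitting the three callbacks means SciPy's Python loop only pays compilation once; each later call runs the compiled JAX code.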

Limitations

Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.
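A minimal illustration of these limitations, using hypothetical dynamics functions: calling plain NumPy on a value that JAX is tracing raises jax.errors.TracerArrayConversionError (the error type that also appears in one of the issue reports below), and in-place item assignment has to be replaced by a functional .at[...].set update.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def bad_dynamics(x, u, t):
    """Converts a traced value to a NumPy array, which JAX tracing cannot allow."""
    del t
    return onp.asarray(x) + u


def good_dynamics(x, u, t):
    """Pure jax.numpy version; uses .at[...].set instead of in-place item assignment."""
    del t
    x_next = x + 0.1 * jnp.sin(u)
    # Writing x_next[0] = ... would raise a TypeError on a JAX array.
    return x_next.at[0].set(jnp.clip(x_next[0], -1.0, 1.0))


def raises_tracer_error():
    """True if jit-compiling bad_dynamics raises JAX's TracerArrayConversionError."""
    try:
        jax.jit(bad_dynamics)(jnp.ones(2), jnp.zeros(2), 0)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(raises_tracer_error())  # True
print(jax.jit(good_dynamics)(jnp.array([2.0, 0.5]), jnp.zeros(2), 0))  # first entry clipped to 1.0
```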

Comments
  • packaging and setup

    Introduces a version number, requirements files, a setup definition, brief installation instructions, and a package-level init file. Also removes an unnecessary package dependency.

    opened by froystig
  • ILQR optimizer doesn't support 1D scalar dynamical systems

    When trying to run a 1D quadratic control-affine nonlinear system of the form shown below, the ILQR implementation is unable to handle scalar-valued systems and results in a dimensionality mismatch error. Please find the code below; the error message was attached as an image.

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import functools
    import os
    
    import jax
    from jax import device_put
    from jax import vmap
    from jax.config import config
    import jax.numpy as np
    import numpy as onp
    
    from trajax import optimizers
    from trajax.integrators import euler
    from trajax.integrators import rk4
    import matplotlib.pyplot as plt
    
    def quadratic_nonlinear(x, u, t, params=(5, 10)):
        """
        Simple quadratic nonlinear system where we introduce reference trajectory as input
        :param x: 1D scalar state
        :param u: 1D input
        :param params: Kp, Kd gains for PD control law
        :return xdot: 1D array of shape 1
        """
        del t
        Kp, Kd = params
        r = np.squeeze(u)
        # xdot = (x ** 2 + Kp * (x - r) - Kd * rdot)/(1 - Kd)
        xdot = x ** 2 + Kp * (x - r)
        return np.array([xdot])
    
    class ILQR_test():
        """
        Testing ILQR implementation in trajax for simple nonlinear systems
        """
        def __init__(self):
            pass
    
    
        def discretize(self, type='euler', dynamics=None):
            if dynamics is not None:
                self.dynamics = dynamics
    
            self.dynamics = euler(self.dynamics, dt=0.01)
            if type != 'euler':
                self.dynamics = rk4(self.dynamics, dt=0.01)
    
    
        def testQuadNonLinear(self, maxiter):
            """
            Calling ilqr on quadratic nonlinear system with input as reference trajectory
            :param maxiter: maximum number of iterations to take in ilqr
            :return: list of ilqr fn output
            """
            horizon = 100
            dynamics = rk4(quadratic_nonlinear, dt=0.01)
    
            true_params = (100.0, 10.0, 1.0)
    
            def cost(params, state, action, t):
                final_weight, stage_weight, action_weight = params
    
                state_err = state - action
                state_cost = stage_weight * (state_err ** 2 + action ** 2)
                # action_cost = action_weight * np.squeeze(action) ** 2
                return np.where(t == horizon, final_weight * state_cost,
                                state_cost)
    
            x0 = np.array([-0.9])
            U0 = np.zeros((horizon, 1))
            X, U, obj, grad, adj, lqr_val, total_iter = optimizers.ilqr(
                functools.partial(cost, true_params), dynamics, x0, U0, maxiter)
            return [X, U, obj, grad, adj, lqr_val, total_iter]
    
    
    test = ILQR_test()
    traj_cost = []
    num_iter = [2, 30, 40, 50]
    
    for i in num_iter:
        print(i)
        # traj = test.apply_ilqr(x0=onp.random.randn(2), U=onp.random.randn(2), maxiter=i, dynamics=rk4(quadratic_nonlinear, dt=0.01))
        # traj = test.testPendulumReadmeExample(maxiter=i)
        traj = test.testQuadNonLinear(maxiter=i)
        traj_cost.append(traj[2])
    
    X = traj[0]
    U = traj[1]
    
    print(traj_cost)
    
    opened by Nusha97
  • Does not work w/ BRAX

    Has anyone tried the solvers on BRAX environments? Here's what I have:

    import trajax
    import jax
    from jax import numpy as jnp
    from jax.flatten_util import ravel_pytree
    import brax
    from brax import envs
    
    def get_f_and_c(env):
        key = jax.random.PRNGKey(0)
        state = env.reset(key)
        _, x2qp = ravel_pytree(state.qp)
        def f(x, u, t):
            qp = x2qp(x)
            nqp, _ = env.sys.step(qp, u)
            return ravel_pytree(nqp)[0]
        def c(x, u, t):
            qp = x2qp(x)
            dstate = state.replace(qp=qp)
            nstate = env.step(dstate, u)
            return -nstate.reward
        return f, c
    
    env = envs.create('inverted_pendulum')
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    x_init, x2qp = ravel_pytree(state.qp)
    
    f, c = get_f_and_c(env)
    
    x, u, cost, *outputs = trajax.optimizers.ilqr(c, f, x_init, jnp.zeros([1, env.action_size]))
    

    which gives:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
    See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
    
    opened by MahanFathi
  • iCEM

    Hi, I'd like to contribute iCEM extensions to the current CEM implementation. I'd like to ask:

    • whether it is already on your list;
    • whether it is OK to extend the current cem() function with additional (optional) features and hyperparameters.

    Regards,

    opened by mkolodziejczyk-piap
  • CEM

    Hi, I noticed there are some problems with running CEM.

    1. Here the last index should be 7: https://github.com/google/trajax/blob/67fd5ed4867914f1ce1ab78819551c12206a7773/trajax/optimizers.py#L753 (similar in random shooting: https://github.com/google/trajax/blob/67fd5ed4867914f1ce1ab78819551c12206a7773/trajax/optimizers.py#L818)

    2. Default hyperparameters should be frozendict https://github.com/google/trajax/blob/main/trajax/optimizers.py#L685-L692

    Would you like a PR for this?

    Regards,

    opened by mkolodziejczyk-piap
  • scipy.optimize.minimize

    Hi, is there an overall performance benefit from passing JAX-derived functions to scipy.optimize.minimize()? Are there any plans to extend jax.scipy.optimize.minimize() to constrained problems? Regards

    opened by soldierofhell
Owner: Google