trajax
A Python library for differentiable optimal control on accelerators.
Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:
- Are automatically efficiently differentiable, via
jax.grad
. - Scale up to parallel instances via
jax.vmap
andjax.pmap
. - Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with
jax.jit
. - Are made available from Python, written with NumPy.
In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad
) to differentiate functions that call a Trajax solver.
This is a research project, not an official Google product.
Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.
Trajectory optimization and optimal control
We consider classical optimal control tasks concerning optimizing trajectories of a given discrete time dynamical system by solving the following problem. Given a cost function c
, dynamics function f
, and initial state x0
, the goal is to compute:
argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))
subject to the constraint that X[0] == x0
and that:
all(X[t + 1] == f(X[t], U[t], t) for t in range(T))
There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.
API
In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …)
as a type denoted F[a, b, …]
. Assume n
is the state dimension, d
is the control dimension, and T
is the time horizon.
Problem setup convention/signature
Setting up a problem requires writing two functions, cost and dynamics, with type signatures:
cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]
Note that even if a dimension n
or d
is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).
Because Trajax uses JAX, the cost
and dynamics
functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy
in place of numpy
are likely to work.
Solvers
If we abbreviate the type of the above two functions as CostFn
and DynamicsFn
, then our solvers have the following type signature prefix in common:
solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput
SolverOutput
is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs)
. The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs
are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).
There are currently four solvers provided: ilqr
, scipy_minimize
, cem
, and random_shooting
. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.
Underlying the ilqr
implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr
.
Objectives
One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:
objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float
Combining this function with JAX's autodiff capabilities offers, for example, a starting point for writing a first-order custom solver. For example:
def improve_controls(cost, dynamics, U, x0, eta, num_iters):
grad_fn = jax.grad(trajax.objective, argnums=(2,))
for i in range(num_iters):
U = U - eta * grad_fn(cost, dynamics, U, x0)
return U
The solvers provided by Trajax are actually built around this objective
function. For instance, the scipy_minimize
solver simply calls scipy.minimize.minimize
with the gradient and Hessian-vector product functions derived from objective
using jax.grad
and jax.hessian
.
Limitations
Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy
in place of standard numpy
, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.