jaxfg - Factor graph-based nonlinear optimization library for JAX.

Overview


Applications include sensor fusion, control, planning, and SLAM. Borrows heavily from a wide set of existing libraries, including Ceres Solver, g2o, GTSAM, minisam, and SwiftFusion.

Features:

  • Autodiff-powered (sparse) Jacobians.
  • Automatic batching of factor computations.
  • Out-of-the-box support for optimization on SO(2), SO(3), SE(2), and SE(3).
  • 100% implemented in Python!

Current limitations:

  • JIT compilation adds significant startup overhead. This could likely be reduced (for example, by specifying more analytical Jacobians), but is mostly unavoidable with JAX/XLA. It limits applicability to systems that run online or require dynamic graph alterations.
  • Python >=3.7 only, due to features needed for generic types.

Installation

scikit-sparse requires SuiteSparse:

sudo apt update
sudo apt install -y libsuitesparse-dev

Then, from your environment of choice:

git clone https://github.com/brentyi/jaxfg.git
cd jaxfg
pip install -e .
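
To check that everything is wired up, importing the package is a quick sanity test:

python -c "import jaxfg"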

Example scripts

Toy pose graph optimization:

python scripts/pose_graph_simple.py

Pose graph optimization from .g2o files:

python scripts/pose_graph_g2o.py --help
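
For orientation, here is a minimal sketch of the core API. It mirrors scripts/pose_graph_simple.py and the snippets quoted in the comments further down; the exact entry points (SE2Variable, StackedFactorGraph.make, VariableAssignments.make_from_defaults, graph.solve) are assumed from those sources and may vary between versions:

import jax.numpy as jnp
import jaxfg
import jaxlie

# Two SE(2) pose variables to optimize.
pose_variables = [jaxfg.geometry.SE2Variable(), jaxfg.geometry.SE2Variable()]

# A prior anchoring the first pose, plus a relative "between" constraint.
factors = [
    jaxfg.geometry.PriorFactor.make(
        variable=pose_variables[0],
        mu=jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0),
        noise_model=jaxfg.noises.DiagonalGaussian(jnp.ones(3)),
    ),
    jaxfg.geometry.BetweenFactor.make(
        variable_T_world_a=pose_variables[0],
        variable_T_world_b=pose_variables[1],
        T_a_b=jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0),
        noise_model=jaxfg.noises.DiagonalGaussian(jnp.ones(3)),
    ),
]

# Stack the factors into a graph and solve from default initial values.
graph = jaxfg.core.StackedFactorGraph.make(factors)
initial = jaxfg.core.VariableAssignments.make_from_defaults(pose_variables)
solution_assignments = graph.solve(initial)

# Read back an optimized pose (a jaxlie.SE2 object).
print(solution_assignments.get_value(pose_variables[0]))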

To-do

  • Preliminary graph, variable, factor interfaces
  • Real vector variable types
  • Refactor into package
  • Nonlinear optimization for MAP inference
    • Conjugate gradient linear solver
    • CHOLMOD linear solver
      • Basic implementation. JIT-able, but no vmap, pmap, or autodiff support.
    • Gauss-Newton implementation
    • Termination criteria
    • Damped least squares
    • Dogleg
    • Inexact Newton steps
    • Revisit termination criteria
    • Reduce redundant code
    • Robust losses
  • Marginalization
    • Working prototype using sksparse/CHOLMOD
    • JAX implementation?
  • Validate g2o example
  • Performance
    • More intentional JIT compilation
    • Re-implement parallel factor computation
    • Vectorized linearization
    • Basic (Jacobi) CGLS preconditioning
  • Manifold optimization (mostly offloaded to jaxlie)
    • Basic interface
    • Manifold optimization on SO2
    • Manifold optimization on SE2
    • Manifold optimization on SO3
    • Manifold optimization on SE3
  • Usability + code health (low priority)
    • Basic cleanup/refactor
      • Better parallel factor interface
      • Separate out utils, lie group helpers
      • Put things in folders
    • Resolve typing errors
    • Cleanup/refactor (more)
    • Package cleanup: dependencies, etc
    • Add CI:
      • mypy
      • lint
      • build
      • coverage
    • More comprehensive tests
    • Clean up docstrings
Comments
  • Changing and indexing prior factors after making graph

    Thanks for the library! I'm trying to update the mu of a prior factor based on the optimized pose of another graph's node. However, I'm stuck figuring out which rows of graph.factor_stacks[1].factor.mu.unit_complex_xy correspond to which nodes. Any help? Thanks

    opened by Nate711 5
  • A possible bug with the order in solution_assignments?

    Hi Brent,

    It seems that solution_assignments is not consistent with the order of pose_variables in pose_graph_simple.py (https://github.com/brentyi/jaxfg/blob/f5204945bb6afa444810e6163e9a913bdbdd636b/scripts/pose_graph_simple.py). With the original code, the output of

    # Grab and print a single variable value at a time.
    print("First pose (jaxlie.SE2 object):")
    print(solution_assignments.get_value(pose_variables[0]))
    print()
    
    print("Second pose (jaxlie.SE2 object):")
    print(solution_assignments.get_value(pose_variables[1]))
    print()
    

    was

    First pose (jaxlie.SE2 object): 
    SE2(unit_complex=[1. 0.], xy=[0.33333 0.     ])
    
    Second pose (jaxlie.SE2 object):
    SE2(unit_complex=[1. 0.], xy=[1.66667 0.     ])
    

    However, if I change the factor graph construction to the following (switching the order of the two PriorFactors):

    factors: List[jaxfg.core.FactorBase] = [
        jaxfg.geometry.PriorFactor.make(
            variable=pose_variables[1],
            mu=jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0),
            noise_model=jaxfg.noises.DiagonalGaussian(jnp.ones(3)),
        ),
        jaxfg.geometry.PriorFactor.make(
            variable=pose_variables[0],
            mu=jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0),
            noise_model=jaxfg.noises.DiagonalGaussian(jnp.ones(3)),
        ),
        jaxfg.geometry.BetweenFactor.make(
            variable_T_world_a=pose_variables[0],
            variable_T_world_b=pose_variables[1],
            T_a_b=jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0),
            noise_model=jaxfg.noises.DiagonalGaussian(jnp.ones(3)),
        ),
    ]
    

    Then, the outputs become

    First pose (jaxlie.SE2 object):
    SE2(unit_complex=[1. 0.], xy=[1.66667 0.     ])
    
    Second pose (jaxlie.SE2 object):
    SE2(unit_complex=[1. 0.], xy=[0.33333 0.     ])
    

    It seems that the order in the solution_assignments is determined by the order of the factors?

    opened by alecwangcq 3
  • Fix storage layout mismatch bugs

    Hey @brentyi, this PR should fix https://github.com/brentyi/jaxfg/issues/12. I'm not sure about sorting the full list of variables, since sorted is not in-place afaik. If performance is an issue, we could scan over the full list first, extract all the unique type(variable) keys, and sort those rather than the full list (see the sketch below). Let me know what you think :)
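
    For illustration, a self-contained sketch of that suggestion (not jaxfg's actual code): bucket the variables by concrete type, then sort only the handful of type keys rather than the full variable list.

        from collections import defaultdict

        def sort_variables_by_type(variables):
            # Bucket variables by their concrete type.
            groups = defaultdict(list)
            for v in variables:
                groups[type(v)].append(v)
            # Sort only the (few) type keys, not the full variable list.
            keys = sorted(groups, key=lambda t: t.__name__)
            return [v for k in keys for v in groups[k]]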

    bug 
    opened by SuperN1ck 3
  • Open source license?

    Hi, this project looks awesome! I definitely see some potential synergy/overlap with JAXopt: https://github.com/google/jaxopt

    Would you consider adding an open source license so others can use your code? GitHub has some good guidance here: https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/creating-a-repository-on-github/licensing-a-repository

    Thanks!

    opened by shoyer 2
  • RealVectorVariable Class Issue

    https://github.com/brentyi/jaxfg/blob/master/jaxfg/core/_variables.py#L151

    I tried to use RealVectorVariable[dim] to instantiate a real vector variable, but got a TypeError: Parameter list is too short. I fixed the error by changing the code from

            class _RealVectorVariable(VariableBase[hints.Array]):
                @staticmethod
                @overrides
                @final
                def get_default_value() -> hints.Array:
                    return jnp.zeros(dim)
    

    to

            class _RealVectorVariable(VariableBase[hints.Array]):
                @classmethod
                @overrides
                @final
                def get_default_value(cls) -> hints.Array:
                    return jnp.zeros(dim)
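
    For reference, the instantiation that triggered the error, with dim=3 chosen arbitrarily (the jaxfg.core path is assumed from the linked file):

        variable = jaxfg.core.RealVectorVariable[3]()
        print(variable.get_default_value())  # zeros of shape (3,)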
    
    opened by alecwangcq 1
  • Update jax.partial to functools.partial

    Hey @brentyi, in the latest JAX version jax.partial was removed in favor of functools.partial (see https://github.com/google/jax/releases/tag/jax-v0.2.21). This PR should fix it.
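
    The rename is mechanical; a self-contained sketch (g here is just a placeholder function):

        import functools

        def g(x, y):
            return x + y

        # Before (removed in JAX >= 0.2.21): f = jax.partial(g, 1)
        # After:
        f = functools.partial(g, 1)
        assert f(2) == 3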

    opened by SuperN1ck 1
  • PyTree registration refactor

    @SuperN1ck would be nice if you could skim through this and let me know if it makes sense!

    Thing that affects you: after this is merged, factor classes will need to be decorated with @register_pytree_dataclass*.

    Main change is to refactor the logic used for designating static dataclass fields. Allows us to:

    • Specify static fields using dataclasses.field metadata.
    • Unify registration logic used across the library. Factors used to do their own thing.

    *I may just pull all this dataclass code into its own library and do some renaming, it seems somewhat generally useful and could help reduce boilerplate in jaxlie as well.
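
    For context, a minimal self-contained sketch of the pattern described above (not jaxfg's actual implementation; the "static" metadata key is an assumption): fields marked static via dataclasses.field metadata are routed into the PyTree's auxiliary data, while the rest become traced children.

        import dataclasses

        import jax

        def register_pytree_dataclass(cls):
            # Split fields into static (auxiliary) and dynamic (traced) names.
            names = [f.name for f in dataclasses.fields(cls)]
            static = [f.name for f in dataclasses.fields(cls) if f.metadata.get("static", False)]
            dynamic = [n for n in names if n not in static]

            def flatten(obj):
                children = tuple(getattr(obj, n) for n in dynamic)
                aux = tuple(getattr(obj, n) for n in static)
                return children, aux

            def unflatten(aux, children):
                return cls(**dict(zip(dynamic, children)), **dict(zip(static, aux)))

            jax.tree_util.register_pytree_node(cls, flatten, unflatten)
            return cls

        @register_pytree_dataclass
        @dataclasses.dataclass
        class ToyFactor:
            weight: float  # traced leaf
            name: str = dataclasses.field(default="prior", metadata={"static": True})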

    opened by brentyi 0
  • About performance comparisons with other graph optimization frameworks

    Hi, is there any performance comparison between this library and others (such as Ceres Solver, g2o, GTSAM)? And, since jaxfg is based on JAX, can it use CUDA to significantly speed up optimization?

    opened by wystephen 4
  • least squares examples

    These are simple least squares examples showing how to apply a factor graph and how to use this codebase. This would have helped me a lot in understanding both the concepts and the implementation. I was at first a bit confused about what Lie groups have to do with factor graphs... :)

    opened by AvanDavad 2
  • SLAM example

    Hello

    This project seems very interesting, thanks for sharing it.

    Could you please provide an example of a SLAM implementation in the examples section, using a popular dataset (KITTI or the Oxford dataset, for example)?

    This would be a very helpful example to learn how to use the library.

    Thanks in advance

    opened by ali-robot 2
Owner
Brent Yi

Related projects

AP1 Transcription Factor Binding Site Prediction

A machine learning project that predicted binding sites of AP1 transcription factor, using ChIP-Seq data and local DNA shape information.

null 1 Jan 21, 2022
JMP is a Mixed Precision library for JAX.

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.

DeepMind 108 Dec 31, 2022
ML Optimizers from scratch using JAX

Toy implementations of some popular ML optimizers using Python/JAX

Shreyansh Singh 38 Jul 29, 2022
DeepSpeed is a deep learning optimization library that makes distributed training easy, efficient, and effective.

DeepSpeed is a deep learning optimization library that makes distributed training easy, efficient, and effective. 10x larger models, 10x faster training.

Microsoft 8.4k Dec 30, 2022
MooGBT is a library for Multi-objective optimization in Gradient Boosted Trees.

MooGBT is a library for Multi-objective optimization in Gradient Boosted Trees. MooGBT optimizes for multiple objectives by defining constraints on sub-objective(s) along with a primary objective. The constraints are defined as upper bounds on the sub-objective loss functions. MooGBT uses an Augmented Lagrangian (AL) based constrained optimization framework with Gradient Boosted Trees to optimize for multiple objectives.

Swiggy 66 Dec 6, 2022
Bayesian optimization based on Gaussian processes (BO-GP) for CFD simulations.

BO-GP Bayesian optimization based on Gaussian processes (BO-GP) for CFD simulations. The BO-GP codes are developed using GPy and GPyOpt. The optimizer

KTH Mechanics 8 Mar 31, 2022
CS 7301: Spring 2021 Course on Advanced Topics in Optimization in Machine Learning

CS 7301: Spring 2021 Course on Advanced Topics in Optimization in Machine Learning

Rishabh Iyer 141 Nov 10, 2022
Bonsai: Gradient Boosted Trees + Bayesian Optimization

Bonsai is a wrapper for the XGBoost and Catboost model training pipelines that leverages Bayesian optimization for computationally efficient hyperparameter tuning.

null 24 Oct 27, 2022
Pyomo is an object-oriented algebraic modeling language in Python for structured optimization problems.

Pyomo is a Python-based open-source software package that supports a diverse set of optimization capabilities for formulating and analyzing optimization models. Pyomo can be used to define symbolic problems, create concrete problem instances, and solve these instances with standard solvers.

Pyomo 1.4k Dec 28, 2022
This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch

This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch. It uses a simple TestEnvironment to test the algorithm

Martin Huber 59 Dec 9, 2022
A Python step-by-step primer for Machine Learning and Optimization

early-ML Presentation: general Machine Learning tutorials. A Python step-by-step primer for Machine Learning and Optimization.

Dimitri Bettebghor 8 Dec 1, 2022
Tools for mathematical optimization region

Tools for mathematical optimization region

林景 15 Nov 30, 2022
Implementation of linesearch Optimization Algorithms in Python

Nonlinear Optimization Algorithms: during my time as a Scientific Assistant at the Karlsruhe Institute of Technology (Germany) I implemented various optimization algorithms.

Paul 3 Dec 6, 2022
Little Ball of Fur - A graph sampling extension library for NetworKit and NetworkX (CIKM 2020)

Little Ball of Fur is a graph sampling extension library for Python. Please look at the Documentation, relevant Paper, Promo video, and External Resources.

Benedek Rozemberczki 619 Dec 14, 2022
A framework for building (and incrementally growing) graph-based data structures used in hierarchical or DAG-structured clustering and nearest neighbor search

A framework for building (and incrementally growing) graph-based data structures used in hierarchical or DAG-structured clustering and nearest neighbor search

Nicholas Monath 31 Nov 3, 2022
PLUR is a collection of source code datasets suitable for graph-based machine learning.

PLUR (Programming-Language Understanding and Repair) is a collection of source code datasets suitable for graph-based machine learning. We provide scripts for downloading, processing, and loading the datasets. This is done by offering a unified API and data structures for all datasets.

Google Research 76 Nov 25, 2022
This repo includes some graph-based CTR prediction models and other representative baselines.

Graph-based CTR prediction: this repository is designed for graph-based CTR prediction methods and includes the group's own graph-based models.

Big Data and Multi-modal Computing Group, CRIPAC 47 Dec 30, 2022
An open source framework that provides a simple, universal API for building distributed applications. Ray is packaged with RLlib, a scalable reinforcement learning library, and Tune, a scalable hyperparameter tuning library.

Ray provides a simple, universal API for building distributed applications, and is packaged with libraries for accelerating machine learning workloads.

null 23.3k Dec 31, 2022
LibTraffic is a unified, flexible and comprehensive traffic prediction library based on PyTorch

LibTraffic is a unified, flexible and comprehensive traffic prediction library, which provides researchers with a credibly experimental tool and a convenient development framework. Our library is implemented based on PyTorch, and includes all the necessary steps or components related to traffic prediction into a systematic pipeline.

null 432 Jan 5, 2023