# jaxfg
![codecov](https://codecov.io/gh/brentyi/jaxfg/branch/master/graph/badge.svg?token=RNJB7EFC8T)
`jaxfg` is a factor graph-based nonlinear least squares library for JAX.
Typical applications include sensor fusion, SLAM, bundle adjustment, and
optimal control.
The premise: we provide a high-level interface for defining probability
densities as factor graphs. MAP inference then reduces to nonlinear
optimization, which we accelerate by analyzing the structure of the graph:
operations on repeated factor and variable types are vectorized, and the
sparsity of graph connections is translated into sparse matrix operations.
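To make the premise concrete, here is a minimal sketch in plain JAX (not the `jaxfg` API) of a toy 1D pose graph: a prior on `x0` plus "between" factors measuring `x_j - x_i`. The names `edges`, `between_residual`, and `gauss_newton_step` are hypothetical. The repeated factor type is evaluated with a single `jax.vmap` call, and the Jacobian comes from autodiff; unlike `jaxfg`, this sketch forms a dense Jacobian and solves the normal equations directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: 1D poses x0..x3, a prior on x0, and
# odometry ("between") factors measuring x_j - x_i.
edges = jnp.array([[0, 1], [1, 2], [2, 3]])  # (num_factors, 2) variable indices
measurements = jnp.array([1.0, 1.0, 1.0])    # measured x_j - x_i
prior_value = 0.0


def between_residual(x_i, x_j, delta):
    # Residual for a single between factor.
    return (x_j - x_i) - delta


def residual_vector(x):
    # vmap evaluates every between factor in one vectorized call,
    # illustrating how repeated factor types can be batched.
    between = jax.vmap(between_residual)(
        x[edges[:, 0]], x[edges[:, 1]], measurements
    )
    prior = jnp.array([x[0] - prior_value])
    return jnp.concatenate([prior, between])


@jax.jit
def gauss_newton_step(x):
    r = residual_vector(x)
    # Dense Jacobian via autodiff; a real solver would exploit sparsity here.
    J = jax.jacfwd(residual_vector)(x)
    # Solve the normal equations (J^T J) dx = -J^T r.
    dx = jnp.linalg.solve(J.T @ J, -J.T @ r)
    return x + dx


x = jnp.zeros(4)
for _ in range(3):
    x = gauss_newton_step(x)
print(x)  # close to [0, 1, 2, 3]
```

Because every residual here is linear in `x`, a single Gauss-Newton step already reaches the MAP estimate; nonlinear factors (e.g. on SE(2) poses) need several iterations.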
Features:
- Autodiff-powered sparse Jacobians.
- Automatic vectorization for repeated factor and variable types.
- Manifold definition interface, with implementations provided for SO(2), SE(2),
SO(3), and SE(3) Lie groups.
- Support for standard JAX function transformations: `jit`, `vmap`, `pmap`,
  `grad`, etc.
- Nonlinear optimizers: Gauss-Newton, Levenberg-Marquardt, Dogleg.
- Sparse linear solvers: conjugate gradient (Jacobi-preconditioned), sparse
Cholesky (via CHOLMOD).
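As a rough illustration of the conjugate gradient option, the sketch below solves a small least-squares step with `jax.scipy.sparse.linalg.cg` on the normal equations, applying `J^T J` matrix-free and using a Jacobi (diagonal) preconditioner. The matrix `J`, the residual `r`, and the helper names are hypothetical, and this is not `jaxfg`'s own solver code, which operates on sparse Jacobians.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical small problem: minimize ||J dx + r||^2.
J = jnp.array([
    [2.0, 0.0, 0.0],
    [1.0, 3.0, 0.0],
    [0.0, 1.0, 4.0],
    [0.0, 0.0, 1.0],
])
r = jnp.array([1.0, -2.0, 0.5, 3.0])


def normal_matvec(v):
    # Apply J^T J to v without ever forming J^T J explicitly.
    return J.T @ (J @ v)


# The diagonal of J^T J is the squared column norms of J.
diag = jnp.sum(J ** 2, axis=0)


def jacobi_precond(v):
    # Jacobi preconditioner: approximate (J^T J)^{-1} by 1 / diag.
    return v / diag


dx, _ = cg(normal_matvec, -J.T @ r, M=jacobi_precond)
```

Passing functions rather than matrices for `A` and `M` keeps the solver matrix-free, which is what makes this approach attractive for large sparse problems.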
This library is released as part of our IROS 2021 paper (more info in our core
experiment repository) and borrows heavily from a wide set of existing
libraries, including GTSAM, Ceres Solver, minisam, SwiftFusion, and g2o. For
technical background and concepts, GTSAM has a great set of tutorials.
## Installation
`scikit-sparse` requires SuiteSparse:

```bash
sudo apt update
sudo apt install -y libsuitesparse-dev
```

Then, from your environment of choice:

```bash
git clone https://github.com/brentyi/jaxfg.git
cd jaxfg
pip install -e .
```

## Example scripts
Toy pose graph optimization:

```bash
python scripts/pose_graph_simple.py
```

Pose graph optimization from `.g2o` files:

```bash
python scripts/pose_graph_g2o.py # For options, pass in a --help flag
```

![](https://github.com/brentyi/jaxfg/raw/master/scripts/data/optimized_sphere2500.png)
## Development
If you're interested in extending this library to define your own factor graphs,
we'd recommend first familiarizing yourself with:
- Pytrees in JAX:
  https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
- Python dataclasses: https://docs.python.org/3/library/dataclasses.html
  - We currently take a "make everything a dataclass" philosophy for software
    engineering in this library. This is convenient for several reasons, but
    notably makes it easy for objects to be registered as pytree nodes. See
    `jax_dataclasses` for details on this.
- Type annotations: https://docs.python.org/3/library/typing.html
  - We rely on generics (`typing.Generic` and `typing.TypeVar`) particularly
    heavily. If you're familiar with C++ this should come very naturally
    (~templates).
- Explicit decorators for overrides/inheritance:
  https://github.com/mkorpela/overrides
  - The `@overrides` and `@final` decorators signal which methods are being
    overridden and/or shouldn't be overridden. Similarly,
    `@abc.abstractmethod` marks methods that subclasses must override.
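A minimal sketch of how these pieces fit together, using only the standard library and core JAX (it does not use `jax_dataclasses` or the `overrides` package). `FactorBase` and `PriorFactor` are hypothetical classes, not `jaxfg`'s. A generic abstract base declares `residual`, a frozen dataclass implements it, and `jax.tree_util.register_pytree_node` tells JAX that `mean` is an array leaf while `weight` is static metadata, so the object can be passed straight into a `jit`-compiled function.

```python
import abc
import dataclasses
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp

# Hypothetical types for illustration only; these are not jaxfg classes.
T = TypeVar("T")


class FactorBase(abc.ABC, Generic[T]):
    """Generic base class: subclasses must implement `residual`."""

    @abc.abstractmethod
    def residual(self, value: T) -> jnp.ndarray:
        ...


@dataclasses.dataclass(frozen=True)
class PriorFactor(FactorBase[jnp.ndarray]):
    mean: jnp.ndarray  # array leaf: traced by JAX transformations
    weight: float      # static metadata: part of the tree structure

    def residual(self, value: jnp.ndarray) -> jnp.ndarray:
        return self.weight * (value - self.mean)


def _flatten(factor):
    # Return (children, aux_data): leaves first, then static data.
    return (factor.mean,), factor.weight


def _unflatten(weight, children):
    return PriorFactor(mean=children[0], weight=weight)


jax.tree_util.register_pytree_node(PriorFactor, _flatten, _unflatten)


@jax.jit
def cost(factor, x):
    # The dataclass is a pytree, so jit accepts it as an argument.
    r = factor.residual(x)
    return jnp.sum(r ** 2)


factor = PriorFactor(mean=jnp.array([1.0, 2.0]), weight=2.0)
print(cost(factor, jnp.zeros(2)))  # 2^2 * (1^2 + 2^2) = 20.0
```

Writing `_flatten`/`_unflatten` by hand for every class gets repetitive, which is the boilerplate a dataclass-based helper library is meant to remove.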
From there, we have a few references for defining your own factor graphs,
factors, and manifolds:
## Current limitations
- In XLA, JIT compilation needs to happen for each unique set of input shapes.
Modifying graph structures can thus introduce significant re-compilation
overheads; this can restrict applications that are dynamic or online.
- Our marginalization implementation is not very good: the current prototype
  relies on sksparse/CHOLMOD and is fairly slow.
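The re-compilation issue can be observed directly in JAX. In this small sketch (not `jaxfg` code), the function body runs in Python only while `jax.jit` traces it, so the hypothetical `trace_count` counter records how many times compilation was triggered: a repeated call with the same input shape and dtype reuses the cached executable, while a new shape forces a fresh trace and compile.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def total_cost(residuals):
    global trace_count
    # Python side effect: executes only during tracing, not on cached calls.
    trace_count += 1
    return jnp.sum(residuals ** 2)


total_cost(jnp.ones(10))   # first call: trace + compile
total_cost(jnp.zeros(10))  # same shape/dtype: cached, no retrace
total_cost(jnp.ones(11))   # new shape (e.g. one extra factor): retrace
print(trace_count)  # 2
```

For a factor graph, adding a single factor or variable changes array shapes in the same way, which is why frequent structural changes can be costly.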
## To-do
This library's still in development mode! Here's our TODO list:
- [x] Preliminary graph, variable, factor interfaces
- [x] Real vector variable types
- [x] Refactor into package
- [x] Nonlinear optimization for MAP inference
  - [x] Conjugate gradient linear solver
  - [x] CHOLMOD linear solver
    - [x] Basic implementation. JIT-able, but no vmap, pmap, or autodiff
          support.
    - [ ] Custom VJP rule? vmap support?
  - [x] Gauss-Newton implementation
  - [x] Termination criteria
  - [x] Damped least squares
  - [x] Dogleg
  - [x] Inexact Newton steps
  - [x] Revisit termination criteria
  - [x] Reduce redundant code
  - [ ] Robust losses
- [x] Marginalization
  - [x] Prototype using sksparse/CHOLMOD (works but fairly slow)
  - [ ] JAX implementation?
- [x] Validate g2o example
- [x] Performance
  - [x] More intentional JIT compilation
  - [x] Re-implement parallel factor computation
  - [x] Vectorized linearization
  - [x] Basic (Jacobi) CGLS preconditioning
- [x] Manifold optimization (mostly offloaded to `jaxlie`)
  - [x] Basic interface
  - [x] Manifold optimization on SO2
  - [x] Manifold optimization on SE2
  - [x] Manifold optimization on SO3
  - [x] Manifold optimization on SE3
- [ ] Usability + code health (low priority)
  - [x] Basic cleanup/refactor
    - [x] Better parallel factor interface
    - [x] Separate out utils, lie group helpers
    - [x] Put things in folders
  - [x] Resolve typing errors
  - [x] Cleanup/refactor (more)
  - [x] Package cleanup: dependencies, etc
  - [x] Add CI:
    - [x] mypy
    - [x] lint
    - [x] build
    - [ ] coverage
  - [ ] More comprehensive tests
  - [ ] Clean up docstrings
- [ ] New name