jaxls
is a library for nonlinear least squares in JAX.
We provide a factor graph interface for specifying and solving least squares problems. We accelerate optimization by analyzing the structure of graphs: repeated factor and variable types are vectorized, and the sparsity of adjacency in the graph is translated into sparse matrix operations.
Use cases are primarily in least squares problems that are (1) sparse and (2) inefficient to solve with gradient-based methods.
Currently supported:
- Automatic sparse Jacobians.
- Optimization on manifolds.
- Examples provided for SO(2), SO(3), SE(2), and SE(3).
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
- Linear subproblem solvers:
- Sparse iterative with Conjugate Gradient.
- Preconditioning: block and point Jacobi.
- Inexact Newton via Eisenstat-Walker.
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Dense Cholesky for smaller problems.
- Sparse iterative with Conjugate Gradient.
For the first iteration of this library, written for IROS 2021, see
jaxfg. jaxls
is a rewrite that is faster
and easier to use. For additional references, see inspirations like
GTSAM, Ceres Solver,
minisam,
SwiftFusion,
g2o.
jaxls
supports python>=3.12
:
pip install git+https://github.com/brentyi/jaxls.git
Optional: CHOLMOD dependencies
By default, we use an iterative linear solver. This requires no extra dependencies. For some problems, like those with banded matrices, a direct solver can be much faster.
For Cholesky factorization via CHOLMOD, we rely on SuiteSparse:
# Option 1: via conda.
conda install conda-forge::suitesparse
# Option 2: via apt.
sudo apt install -y libsuitesparse-dev
You'll also need scikit-sparse:
pip install scikit-sparse
import jaxls
import jaxlie
Defining variables. Each variable is given an integer ID. They don't need to be contiguous.
pose_vars = [jaxls.SE2Var(0), jaxls.SE2Var(1)]
Defining factors. Factors are defined using a callable cost function and a set of arguments.
# Factors take two arguments:
# - A callable with signature `(jaxls.VarValues, *Args) -> jax.Array`.
# - A tuple of arguments: the type should be `tuple[*Args]`.
#
# All arguments should be PyTree structures. Variable types within the PyTree
# will be automatically detected.
factors = [
# Cost on pose 0.
jaxls.Factor(
lambda vals, var, init: (vals[var] @ init.inverse()).log(),
(pose_vars[0], jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
),
# Cost on pose 1.
jaxls.Factor(
lambda vals, var, init: (vals[var] @ init.inverse()).log(),
(pose_vars[1], jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
),
# Cost between poses.
jaxls.Factor(
lambda vals, var0, var1, delta: (
(vals[var0].inverse() @ vals[var1]) @ delta.inverse()
).log(),
(pose_vars[0], pose_vars[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
),
]
Factors with similar structure, like the first two in this example, will be vectorized under-the-hood.
Solving optimization problems. We can set up the optimization problem, solve it, and print the solutions:
graph = jaxls.FactorGraph.make(factors, pose_vars)
solution = graph.solve()
print("All solutions", solution)
print("Pose 0", solution[pose_vars[0]])
print("Pose 1", solution[pose_vars[1]])
There are many practical features that we don't currently support:
- GPU accelerated Cholesky factorization. (for CHOLMOD we wrap scikit-sparse, which runs on CPU only)
- Covariance estimation / marginalization.
- Incremental solves.
- Analytical Jacobians.