On CPU, provides an updated jax.lax.scan implementation that filters blocks of keypoints in a single call, parallelized at all stages (over keypoints, over time points, etc.).
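A minimal sketch of the "scan over time, vectorize over keypoints" pattern follows. It assumes a hypothetical scalar random-walk model for each keypoint coordinate, x_t = x_{t-1} + w_t with Var(w_t) = s (the smoothing parameter) and y_t = x_t + v_t with Var(v_t) = r. The names `filter_one_series` and `filter_block` are illustrative, not the PR's API. Each column of a (T, K) block is filtered with jax.lax.scan, and jax.vmap maps that scan across the K keypoint columns.

```python
import jax
import jax.numpy as jnp


def filter_one_series(ys, s, r, m0=0.0, p0=1.0):
    """Scalar random-walk Kalman filter over time via jax.lax.scan (hypothetical model)."""

    def step(carry, y):
        m, p = carry
        p_pred = p + s          # predict: x_t = x_{t-1} + w_t, Var(w_t) = s
        S = p_pred + r          # innovation variance for y_t = x_t + v_t, Var(v_t) = r
        k = p_pred / S          # Kalman gain
        resid = y - m           # innovation (predicted mean equals previous filtered mean)
        m_new = m + k * resid
        p_new = (1.0 - k) * p_pred
        nll_t = 0.5 * (jnp.log(2.0 * jnp.pi * S) + resid**2 / S)
        return (m_new, p_new), (m_new, p_new, nll_t)

    init = (jnp.asarray(m0, ys.dtype), jnp.asarray(p0, ys.dtype))
    _, (means, variances, nll_terms) = jax.lax.scan(step, init, ys)
    return means, variances, jnp.sum(nll_terms)


# Filter a (T, K) block of keypoint coordinates: vmap over the K columns, scan over time.
filter_block = jax.jit(
    jax.vmap(filter_one_series, in_axes=(1, None, None), out_axes=(1, 1, 0))
)
```

Calling `filter_block(ys, 0.5, 1.0)` on a (T, K) array returns filtered means and variances of shape (T, K) plus one negative log-likelihood per keypoint. Within each series the time recursion stays sequential; the parallelism here comes from batching keypoints.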
Also includes a GPU parallel-scan implementation of the Kalman filter, again fully jitted. It is used to estimate the smoothing parameter: the code computes the Kalman filter and the negative log-likelihood very quickly.
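A hedged sketch of what a parallel-scan Kalman filter can look like in JAX is shown below. It uses the associative formulation of Särkkä and García-Fernández ("Temporal Parallelization of Bayesian Smoothers", 2021) via jax.lax.associative_scan, specialized to the same hypothetical scalar random-walk model (F = H = 1, Q = s, R = r). The PR does not say which formulation it uses, and `parallel_filter_nll` is an illustrative name. Each time step becomes an element (A, b, C, eta, J). The prefix combination of elements 0..t yields the filtered mean b and variance C at step t. The negative log-likelihood then follows from the one-step-ahead predictions, which are also computed in parallel.

```python
import jax
import jax.numpy as jnp


def _combine(elem_i, elem_j):
    """Associative operator for scalar filtering elements (i earlier, j later)."""
    A_i, b_i, C_i, eta_i, J_i = elem_i
    A_j, b_j, C_j, eta_j, J_j = elem_j
    D = 1.0 + C_i * J_j  # scalar stand-in for the matrix (I + C_i J_j)
    A = A_j * A_i / D
    b = A_j * (b_i + C_i * eta_j) / D + b_j
    C = A_j * C_i * A_j / D + C_j
    eta = A_i * (eta_j - J_j * b_i) / D + eta_i
    J = A_i * J_j * A_i / D + J_i
    return A, b, C, eta, J


@jax.jit
def parallel_filter_nll(ys, s, r, m0=0.0, p0=1.0):
    ones = jnp.ones_like(ys)
    # Generic elements for steps t >= 1 (F = H = 1, Q = s, R = r).
    S = s + r
    K = s / S
    A = (1.0 - K) * ones
    b = K * ys
    C = (1.0 - K) * s * ones
    eta = ys / S
    J = ones / S
    # The first element absorbs the prior N(m0, p0) on the initial state.
    p1 = p0 + s
    K1 = p1 / (p1 + r)
    A = A.at[0].set(0.0)
    b = b.at[0].set(m0 + K1 * (ys[0] - m0))
    C = C.at[0].set((1.0 - K1) * p1)
    eta = eta.at[0].set(0.0)
    J = J.at[0].set(0.0)
    # Prefix combinations give filtered means (b) and variances (C) in O(log T) depth.
    _, means, variances, _, _ = jax.lax.associative_scan(_combine, (A, b, C, eta, J))
    # One-step-ahead predictive moments, computed in parallel from the filtered ones.
    m_pred = jnp.concatenate([jnp.asarray(m0, dtype=ys.dtype)[None], means[:-1]])
    p_pred = jnp.concatenate([jnp.asarray(p0, dtype=ys.dtype)[None], variances[:-1]]) + s
    S_pred = p_pred + r
    resid = ys - m_pred
    nll = 0.5 * jnp.sum(jnp.log(2.0 * jnp.pi * S_pred) + resid**2 / S_pred)
    return means, variances, nll
```

In exact arithmetic this produces the same filtered moments and likelihood as the sequential scan. Only the dependency structure changes, which is what lets a GPU process all time steps at once.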
An Optax (JAX) optimizer uses gradients obtained by differentiating through all of the above implementations, enabling fast maximum-likelihood estimation (MLE) of the smoothing parameter. Note: with minor modifications, this could in principle also estimate the observation noise parameters.
Key changes: