brentyi / jaxls

Sparse nonlinear least squares in JAX
MIT License

Add numerical constrained noise model #2

Status: Closed (SuperN1ck closed this 3 years ago)

SuperN1ck commented 3 years ago

This PR adds a very simple numerical constrained noise model by replacing infinite values in the precision diagonal array with a very big number.

SuperN1ck commented 3 years ago

Thanks for the feedback Brent! Let me answer your questions :)

> This is meant to be a port of GTSAM's Constrained noise model, right? Do you think you could comment on the implementation differences between how infinite precision values are handled here vs. in their implementation?

> It seems like they have quite a bit of extra logic (for example, in whiten), but it's not immediately obvious why.

Yes, it's inspired by GTSAM's implementation. In general, the data flow can be roughly described as DistanceMeasure(Whiten(ResidualVector)), right? In the special case of infinite precision values (i.e., a hard constraint), whitening multiplies the corresponding residual entry by an infinite value. That produces inf entries (or NaN ones, since 0 * inf is NaN), which break both the cost and its gradients.
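Here is a minimal JAX sketch of that failure mode. It assumes the diagonal model is stored as a vector of sqrt-precisions (1 / sigma) and that whitening is an elementwise product. The names sqrt_precision_diagonal, whiten, and squared_mahalanobis are illustrative only and are not jaxfg's API:

```python
import jax
import jax.numpy as jnp

# Hypothetical diagonal noise model: the second residual entry is a hard
# constraint, so its sqrt-precision (1 / sigma) is infinite.
sqrt_precision_diagonal = jnp.array([2.0, jnp.inf])

def whiten(residual):
    # Diagonal whitening: scale each entry by its sqrt-precision.
    return sqrt_precision_diagonal * residual

def squared_mahalanobis(residual):
    # DistanceMeasure(Whiten(ResidualVector)) for a diagonal model.
    w = whiten(residual)
    return jnp.sum(w ** 2)

satisfied = jnp.array([0.5, 0.0])  # constraint exactly satisfied
violated = jnp.array([0.5, 0.1])   # constraint slightly violated

print(squared_mahalanobis(violated))   # inf: any violation blows up the cost
print(squared_mahalanobis(satisfied))  # nan: 0 * inf in the whitened residual
print(jax.grad(squared_mahalanobis)(satisfied))  # second entry is nan
```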

GTSAM handles this problem in two steps. First, its whitening step leaves the residual entry a(i) untouched wherever the sigma b(i) is zero, c(i) = (b(i) == 0.0) ? a(i) : a(i) / b(i), so it never produces infinite values. Second, its distance measure scales those untouched entries by the square root of the penalty mu_[i], via w[i] = v[i] * sqrt(mu_[i]); (where v[i] corresponds to c(i) in the previous snippet).
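For comparison, the two GTSAM steps could be ported to JAX roughly as follows. This is a sketch under assumptions: the function names are hypothetical, mu = 1000 is an arbitrary illustration value, and only the two quoted formulas are taken from GTSAM:

```python
import jax
import jax.numpy as jnp

def gtsam_style_whiten(v, sigmas):
    """Divide by sigma, except where sigma == 0 (hypothetical helper)."""
    constrained = sigmas == 0.0
    # Swap zero sigmas for 1.0 before dividing so the branch that jnp.where
    # discards never holds inf/NaN, which would otherwise leak into gradients.
    safe_sigmas = jnp.where(constrained, 1.0, sigmas)
    return jnp.where(constrained, v, v / safe_sigmas)

def gtsam_style_squared_distance(v, sigmas, mu):
    """Squared distance with a sqrt(mu) penalty on constrained entries."""
    w = gtsam_style_whiten(v, sigmas)
    # Mirrors w[i] = v[i] * sqrt(mu_[i]) for constrained dimensions.
    w = jnp.where(sigmas == 0.0, v * jnp.sqrt(mu), w)
    return jnp.dot(w, w)

sigmas = jnp.array([0.5, 0.0])
mu = jnp.array([1000.0, 1000.0])  # arbitrary illustration value
v = jnp.array([1.0, 0.01])

# 2.0**2 + 0.01**2 * 1000, i.e. approximately 4.1
print(gtsam_style_squared_distance(v, sigmas, mu))
# Finite gradient, approximately [8.0, 20.0]
print(jax.grad(gtsam_style_squared_distance)(v, sigmas, mu))
```

The safe_sigmas step is a JAX-specific detail rather than part of GTSAM's logic: jnp.where still differentiates through the branch it discards, so an inf there would turn the gradient into NaN.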

Since jaxfg currently doesn't compute a noise-model-specific distance but instead takes the global Mahalanobis distance over the whole residual vector (and I currently see no reason to change that), I combined both steps into a single jnp.where.
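Folding both steps into one jnp.where works, but where the jnp.where is placed matters for gradients. The sketch below uses hypothetical names, and PENALTY = 1e6 is an arbitrary stand-in for infinite precision. It contrasts selecting between whitened residuals with selecting between scales:

```python
import jax
import jax.numpy as jnp

PENALTY = 1e6  # hypothetical penalty standing in for infinite precision

sqrt_precision_diagonal = jnp.array([2.0, jnp.inf])

def cost_naive(residual):
    # Selects per entry between two whitened residuals; the discarded
    # branch (residual * inf) still takes part in the backward pass.
    w = jnp.where(
        jnp.isinf(sqrt_precision_diagonal),
        residual * jnp.sqrt(PENALTY),
        residual * sqrt_precision_diagonal,
    )
    return jnp.sum(w ** 2)

def cost_safe(residual):
    # Selects the scale first, so inf never multiplies the residual.
    scale = jnp.where(
        jnp.isinf(sqrt_precision_diagonal),
        jnp.sqrt(PENALTY),
        sqrt_precision_diagonal,
    )
    w = residual * scale
    return jnp.sum(w ** 2)

residual = jnp.array([1.0, 0.01])
print(cost_naive(residual), cost_safe(residual))  # same finite cost
print(jax.grad(cost_naive)(residual))  # second entry is nan
print(jax.grad(cost_safe)(residual))   # finite
```

Both versions return the same cost. Only the second gives usable gradients, because in the first the discarded residual * inf branch contributes 0 * inf = NaN in the backward pass.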

> With how things are currently set up, I wonder if this logic could just be folded directly into the DiagonalGaussian class; it looks like it could be written as a simple upper bound on the precision values.
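Here is a rough sketch of the upper-bound idea, assuming the model stores a precision diagonal. The name bounded_sqrt_precision and the cap MAX_PRECISION = 1e12 are hypothetical choices, not values from jaxfg:

```python
import jax.numpy as jnp

# Hypothetical cap on precision (equivalently, sigma >= 1e-6); not a jaxfg constant.
MAX_PRECISION = 1e12

def bounded_sqrt_precision(precision_diagonal, max_precision=MAX_PRECISION):
    # jnp.minimum maps inf (hard constraints) and merely huge finite
    # precisions (tiny covariances) to the same finite bound.
    return jnp.sqrt(jnp.minimum(precision_diagonal, max_precision))

precision = jnp.array([4.0, jnp.inf, 1e20])
print(bounded_sqrt_precision(precision))  # approximately [2.0, 1e6, 1e6]
```

Compared with only replacing inf, a cap also catches large but finite precisions that come from tiny covariances.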

Sounds good to me! If you're fine with not having an explicit class, I'd suggest that in make_from_covariance we check for infinite precision values, or even directly for very small elements in the covariance diagonal. In that case, instead of taking the square root, we could assign the penalty directly to the respective position in sqrt_precision_diagonal. Another benefit of moving the changes directly into DiagonalGaussian is that it prevents users from running into hard-to-debug errors (e.g., when their covariance is too small).
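The make_from_covariance suggestion could look something like the following. DiagonalGaussianSketch, penalty, and min_variance are hypothetical stand-ins rather than jaxfg's actual DiagonalGaussian signature. Following the suggestion literally, the penalty is written directly into sqrt_precision_diagonal, so the effective precision of a constrained entry is penalty**2:

```python
from dataclasses import dataclass

import jax.numpy as jnp

@dataclass(frozen=True)
class DiagonalGaussianSketch:
    """Hypothetical stand-in for a diagonal noise model; not jaxfg's class."""

    sqrt_precision_diagonal: jnp.ndarray

    @classmethod
    def make_from_covariance(cls, diagonal, penalty=1e3, min_variance=1e-12):
        diagonal = jnp.asarray(diagonal)
        # Treat zero or near-zero variances as hard constraints.
        constrained = diagonal <= min_variance
        # Avoid dividing by zero in the branch jnp.where discards.
        safe_diagonal = jnp.where(constrained, 1.0, diagonal)
        sqrt_precision = jnp.where(
            constrained,
            penalty,  # assigned directly instead of taking a square root
            1.0 / jnp.sqrt(safe_diagonal),
        )
        return cls(sqrt_precision)

    def whiten(self, residual):
        return self.sqrt_precision_diagonal * residual

noise = DiagonalGaussianSketch.make_from_covariance(jnp.array([0.25, 0.0, 1e-20]))
print(noise.sqrt_precision_diagonal)  # approximately [2.0, 1000.0, 1000.0]
```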

brentyi commented 3 years ago

Closing as per conversation -- maybe revisit later!