mjo22 / cryojax

Cryo-electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1

What is the right way to estimate rotations by autograd? #61

Open · mjo22 opened this issue 10 months ago

mjo22 commented 10 months ago

The package jaxlie, a dependency of cryojax, seems to have a way of doing pose estimation. See the example here: https://github.com/brentyi/jaxlie/blob/master/examples/se3_optimization.py, and the jax.grad wrappers here: https://github.com/brentyi/jaxlie/blob/master/jaxlie/manifold/_backprop.py.

We should have a Pose subclass whose parameterization is well suited to gradient descent. I think this should involve parameterizing updates in the Lie algebra su(2) (equivalently so(3)) and exponentiating to map tangent vectors onto SO(3). This seems to be what they do; a minimal sketch of the pattern follows.
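A hedged sketch of that pattern, not cryojax API: the toy loss, point clouds, and step size are illustrative assumptions, and only `jaxlie.SO3` is taken from the linked library. The loss is differentiated with respect to a tangent perturbation at the current rotation, and the step is retracted back onto SO(3) with the exponential map.

```python
import jax
import jax.numpy as jnp
import jaxlie

def loss(R: jaxlie.SO3, points: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    # Toy objective: squared error between rotated points and target points.
    return jnp.sum((points @ R.as_matrix().T - targets) ** 2)

@jax.jit
def step(R: jaxlie.SO3, points, targets, lr=1e-2):
    # Differentiate with respect to a tangent perturbation delta (a 3-vector in
    # the Lie algebra), evaluated at delta = 0, rather than the raw quaternion.
    def local_loss(delta):
        return loss(R @ jaxlie.SO3.exp(delta), points, targets)

    g = jax.grad(local_loss)(jnp.zeros(3))
    # Retract the gradient step back onto SO(3): R <- R @ exp(-lr * g).
    return R @ jaxlie.SO3.exp(-lr * g)
```

Each update stays exactly on the manifold, so no renormalization is needed; the `jaxlie.manifold` grad/rplus wrappers linked above appear to package this same local-parameterization trick for pytrees of transforms.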

mjo22 commented 10 months ago

See here, from the jaxlie developer: https://github.com/brentyi/jaxlie/issues/14

brentyi commented 10 months ago

As a note, my personal experience is that there will be some runtime differences between methods, but it doesn't really matter if you're doing first-order optimization, as long as what you're doing is numerically stable. Even optimizing directly with respect to quaternions and just normalizing after each step is good enough.

If you want to explore Newton method variants, the tangent-space stuff becomes much more important.
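For reference, the quaternion-and-renormalize baseline mentioned above can be as simple as the following sketch (plain JAX; the wxyz quaternion convention, toy loss, and step size are illustrative assumptions, not anything from cryojax):

```python
import jax
import jax.numpy as jnp

def quat_to_matrix(q: jnp.ndarray) -> jnp.ndarray:
    # Rotation matrix from a unit quaternion (w, x, y, z).
    w, x, y, z = q
    return jnp.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y)],
        [2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)],
    ])

def loss(q, points, targets):
    # Toy objective: squared error between rotated points and target points.
    return jnp.sum((points @ quat_to_matrix(q).T - targets) ** 2)

@jax.jit
def step(q, points, targets, lr=1e-2):
    # Plain gradient step on the raw quaternion entries...
    g = jax.grad(loss)(q, points, targets)
    q = q - lr * g
    # ...followed by projection back onto the unit sphere.
    return q / jnp.linalg.norm(q)
```

The projection after each update is the projected-gradient-descent behavior discussed in the reply below.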

mjo22 commented 10 months ago

I see. Thank you for the insight @brentyi! This is very interesting. Why is this? Do second derivatives when optimizing with respect to quaternions tend to be less numerically stable? Or would there be some other issue, like the error surface not being very smooth? I'm also curious if there is a geometrical explanation for why this should be true. I haven't gotten the chance yet to go through the reference you sent my way…

brentyi commented 10 months ago

My intuition here is that it mostly comes down to how large the steps your optimizer takes are. With gradient descent the steps tend to be fairly small, so naive handling of the parameterization won't result in large deviations from the manifold after each step. This is why projected gradient descent can converge nicely, even in the face of noisy gradients.

On the other hand, something like Gauss-Newton will try to jump to the global optimum (of a linear approximation of your problem) at each iteration. This produces larger steps that might take you very far away from the rotation manifold if you just treat the parameters as Euclidean; it's also much more sensitive to the quality of the Jacobian/Hessian estimates, which should improve when you take the topology into account.
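To make the contrast concrete, here is a hedged sketch of a single Gauss-Newton step solved in the tangent space (the residual, damping, and use of `jaxlie.SO3` are illustrative assumptions): the Jacobian is taken with respect to an so(3) perturbation rather than the raw quaternion entries, the normal equations are solved for the full step, and the result is retracted back onto the manifold, so even a large step cannot leave SO(3).

```python
import jax
import jax.numpy as jnp
import jaxlie

def residuals(R: jaxlie.SO3, points, targets):
    # Stacked residual vector between rotated points and their targets.
    return (points @ R.as_matrix().T - targets).ravel()

def gauss_newton_step(R: jaxlie.SO3, points, targets):
    # Linearize r(R @ exp(delta)) around delta = 0, so the Jacobian is with
    # respect to an so(3) tangent perturbation, not the quaternion storage.
    def r(delta):
        return residuals(R @ jaxlie.SO3.exp(delta), points, targets)

    delta0 = jnp.zeros(3)
    J = jax.jacfwd(r)(delta0)
    # Solve the normal equations J^T J delta = -J^T r(0), with small damping.
    delta = jnp.linalg.solve(J.T @ J + 1e-9 * jnp.eye(3), -J.T @ r(delta0))
    # Retract the full Gauss-Newton step back onto SO(3).
    return R @ jaxlie.SO3.exp(delta)
```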