facebookresearch / theseus

A library for differentiable nonlinear optimization
MIT License

Faster log map implementation #627

Closed luisenp closed 11 months ago

luisenp commented 1 year ago

The current autograd operators in torchlie.functional result in duplicated computation for function calls such as

res = torchlie.functional.SO3.log(g, jacobians=jlist)

because these are computed as follows:

jac, _ = _jlog_impl(group)
res = _log_autograd_fn(group) # repeats all the computation done in `_op_impl`
return res, jac

This PR adds a wrapper that reuses the computation while preserving the custom backward implementation, as follows:

jac, res_value = _jlog_impl(group)
res_tensor = _log_passthrough_fn(group, res_value, jac[0])
return res_tensor, jac
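
For reference, here is a minimal sketch of the passthrough idea as a custom torch.autograd.Function. This is not the actual torchlie code: the class name, the flat (..., d)-shaped input convention, and the einsum in backward are illustrative assumptions; the real wrapper works with the group's tensor parameterization and keeps the existing custom backward.

import torch

class _LogPassthroughSketch(torch.autograd.Function):
    # Illustrative only: assumes the op maps a (..., d) input to a (..., d)
    # output and that `jacobian` holds d(out)/d(in) with shape (..., d, d).
    @staticmethod
    def forward(ctx, group, precomputed_out, jacobian):
        # The value and Jacobian were already produced by the jop helper
        # (e.g., _jlog_impl), so forward only saves the Jacobian and returns
        # the precomputed result instead of recomputing the log map.
        ctx.save_for_backward(jacobian)
        return precomputed_out

    @staticmethod
    def backward(ctx, grad_output):
        (jacobian,) = ctx.saved_tensors
        # Vector-Jacobian product with the saved Jacobian; nothing from the
        # forward computation is repeated here.
        grad_group = torch.einsum("...i,...ij->...j", grad_output, jacobian)
        # Only the original group input receives a gradient; the precomputed
        # output and the Jacobian are treated as non-differentiable.
        return grad_group, None, None

With a wrapper of this form, the call pattern in the snippet above (jop first, then the passthrough) builds a graph whose backward uses the saved Jacobian without re-running the log map.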

Benchmarking with this script shows runtime improvements of 15-23% when backpropagating through local cost functions in Theseus. I haven't written dedicated unit tests yet, but both versions produce the same loss history in the script above, so the change looks correct.
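
As a quick sanity check independent of that script, one could compare the two call paths directly. The snippet below is a rough sketch, not the PR's benchmark: it assumes torchlie exposes SO3.rand for sampling random rotations (as in Theseus' labs-style API), which is not confirmed by this PR description, and that SO3.log can be called with or without the jacobians argument.

import torch
import torchlie.functional as lieF

# Rough sanity-check sketch. Assumes lieF.SO3.rand exists for sampling
# random group elements; adjust how `g` is constructed if your torchlie
# version differs.
g = lieF.SO3.rand(8, requires_grad=True)

# Path 1: log with Jacobians requested (exercises the passthrough wrapper).
jlist = []
res_with_jac = lieF.SO3.log(g, jacobians=jlist)
(grad_with_jac,) = torch.autograd.grad(res_with_jac.sum(), g)

# Path 2: plain log call (the original autograd path).
res_plain = lieF.SO3.log(g)
(grad_plain,) = torch.autograd.grad(res_plain.sum(), g)

# Both the values and the gradients should match.
assert torch.allclose(res_with_jac, res_plain)
assert torch.allclose(grad_with_jac, grad_plain)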