qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Fix JAX version of t_span and t_eval merging functions #153

Closed: DanPuzzuoli closed this pull request 2 years ago.

DanPuzzuoli commented 2 years ago

Summary

Closes #147

Fixes the merge_t_args_jax and trim_t_results_jax functions. I think this is a final, comprehensive solution to the problem that led to #147.

Details and comments

The bug stems from an annoying technical issue that we've dealt with for a while: when the array of times passed to odeint contains duplicate values, nan values appear in the solution.
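Duplicates arise when t_eval already contains an endpoint of t_span and that endpoint is added again during merging. The sketch below illustrates this with a hypothetical naive_merge helper (not the qiskit-dynamics implementation) that prepends t_span[0] and appends t_span[1] to t_eval, plus a hypothetical has_duplicates check that assumes the times are sorted:

```python
import jax.numpy as jnp


def naive_merge(t_span, t_eval):
    """Hypothetical naive merge: returns [t_span[0], *t_eval, t_span[1]]."""
    t_span = jnp.asarray(t_span)
    t_eval = jnp.asarray(t_eval)
    return jnp.concatenate([t_span[:1], t_eval, t_span[1:]])


def has_duplicates(t):
    """Check for repeated times, assuming t is sorted (increasing or decreasing)."""
    return bool(jnp.any(jnp.diff(t) == 0))


# t_eval contains both endpoints of t_span, so both end up duplicated.
merged = naive_merge([0.0, 1.0], [0.0, 0.5, 1.0])
print(merged.tolist())  # [0.0, 0.0, 0.5, 1.0, 1.0]
print(has_duplicates(merged))  # True
```

When t_eval lies strictly inside t_span, the same naive merge produces no duplicates, which is why the problem only shows up for particular choices of t_eval.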

I thought I had fixed this problem in #125: the positions at which the nans appear in the solution when duplicates are present are predictable, so trim_t_results_jax was written to avoid these values. However, #147 was opened after we discovered that this change caused gradients of solutions to evaluate to nan. I think this is because the gradient computation of odeint always uses all of the intermediate computed solution values, so even though we were removing the nans from the output, the gradient computation still hit them, and there was no way for us to avoid that.

This PR fixes the problem by changing the "merging strategy" of merge_t_args_jax and trim_t_results_jax. Rather than just appending the endpoints of t_span to the ends of t_eval, we append them and then, conditional on whether the endpoints are equal, modify the combined array so that it contains no duplicates. When trimming the results, we then choose the correct entry for each time in t_eval. Importantly, this will not fundamentally change the solver's behaviour.
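One way to realize this strategy is sketched below, under stated assumptions: t_eval is sorted in the direction of integration, lies within t_span, has no internal duplicates, and t_span[0] != t_span[1]. The names merge_t_args_sketch and trim_t_results_sketch are hypothetical, and moving a duplicated time to the midpoint of its neighbours is an illustrative choice, not necessarily what this PR does. The endpoint checks use jnp.where instead of Python if statements so the functions can be traced by jax.jit:

```python
import jax
import jax.numpy as jnp


def merge_t_args_sketch(t_span, t_eval):
    """Hypothetical merge of t_span and t_eval into one duplicate-free time array.

    Returns [t_span[0], *t_eval, t_span[1]], except that an entry of t_eval equal
    to an endpoint of t_span is moved to the midpoint of its neighbours.
    """
    t_span = jnp.asarray(t_span)
    t_eval = jnp.asarray(t_eval)
    merged = jnp.concatenate([t_span[:1], t_eval, t_span[1:]])

    # If t_eval[0] == t_span[0], move merged[1] to the midpoint of merged[0] and merged[2].
    merged = merged.at[1].set(
        jnp.where(merged[1] == merged[0], 0.5 * (merged[0] + merged[2]), merged[1])
    )
    # If t_eval[-1] == t_span[1], move merged[-2] to the midpoint of merged[-3] and merged[-1].
    merged = merged.at[-2].set(
        jnp.where(merged[-2] == merged[-1], 0.5 * (merged[-3] + merged[-1]), merged[-2])
    )
    return merged


def trim_t_results_sketch(results, t_span, t_eval):
    """Hypothetical trim: map solver output on the merged times back onto t_eval.

    Assumes results[i] is the solution at merge_t_args_sketch(t_span, t_eval)[i].
    """
    t_span = jnp.asarray(t_span)
    t_eval = jnp.asarray(t_eval)
    trimmed = results[1:-1]
    # A duplicated endpoint was moved in the merged array, so take its value
    # from the corresponding t_span endpoint instead.
    trimmed = trimmed.at[0].set(jnp.where(t_eval[0] == t_span[0], results[0], trimmed[0]))
    trimmed = trimmed.at[-1].set(jnp.where(t_eval[-1] == t_span[1], results[-1], trimmed[-1]))
    return trimmed


t_span = jnp.array([0.0, 1.0])
t_eval = jnp.array([0.0, 0.5, 1.0])
merged = jax.jit(merge_t_args_sketch)(t_span, t_eval)
print(merged.tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]

# Stand-in for solver output: pretend the solution is y(t) = t**2.
results = merged**2
print(trim_t_results_sketch(results, t_span, t_eval).tolist())  # [0.0, 0.25, 1.0]
```

In this sketch the merged array is strictly monotonic, so the solver never receives repeated times, and the values computed at the moved points are simply discarded by the trim step.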