choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Correctly stagger JIT until first call #505

Closed: Lnaden closed this pull request 1 year ago

Lnaden commented 1 year ago

Follow-up to #504 and #496. The initial implementation didn't actually defer enabling 64-bit JAX until the first call. This implementation defers the JIT call until the function is actually used. This shouldn't break JAX's cache: the function object in question never changes, so its hash won't change and we still get all the accelerated code.
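The deferral pattern described above can be sketched as a wrapper that does nothing at import or decoration time and only imports JAX, turns on `jax_enable_x64`, and calls `jax.jit` on the first call. This is a minimal illustration, not pymbar's actual code. `lazy_jit` and `_enable_x64_and_jit` are hypothetical names. The `compiler` parameter is also an illustrative assumption, added so the deferral can be checked without JAX installed.

```python
import functools
from typing import Any, Callable, Optional


def _enable_x64_and_jit(func: Callable) -> Callable:
    """Hypothetical default compiler: configure JAX lazily, then jit.

    JAX is imported and configured here, not at module import time, so
    nothing touches JAX's global config until the wrapped function runs.
    """
    import jax

    # Documented JAX switch for 64-bit precision; it should take effect
    # before any JAX arrays are created.
    jax.config.update("jax_enable_x64", True)
    return jax.jit(func)


def lazy_jit(
    func: Callable,
    compiler: Callable[[Callable], Callable] = _enable_x64_and_jit,
) -> Callable:
    """Hypothetical decorator that defers compilation until the first call."""
    compiled: Optional[Callable] = None

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal compiled
        if compiled is None:
            # First call only: build the jitted callable exactly once.
            compiled = compiler(func)
        # Later calls reuse the same jitted object, so shapes it has
        # already traced are not recompiled.
        return compiled(*args, **kwargs)

    return wrapper
```

Used as a plain decorator (`@lazy_jit` above a `def`), importing the defining module leaves JAX's configuration untouched. The first call enables 64-bit mode and compiles. Because the underlying Python function and the resulting jitted callable never change after that, JAX's per-function trace cache keeps being reused, which matches the reasoning in the PR description.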

codecov[bot] commented 1 year ago

Codecov Report

Merging #505 (14a16b6) into master (a5fa114) will decrease coverage by 0.05%. The diff coverage is 100.00%.