choderalab / pymbar

Python implementation of the multistate Bennett acceptance ratio (MBAR)
http://pymbar.readthedocs.io
MIT License

Warn about JAX bitsize changes #504

Closed: Lnaden closed this 1 year ago

Lnaden commented 1 year ago

Partial solution to #496

This PR moves the switch to JAX 64-bit mode from import time to the first call of a JIT-compiled function. A warning is issued on import if JAX is detected in 32-bit mode, and a second warning is issued when the mode is toggled on that first call.
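The import-time half of this could look roughly like the sketch below. It is an illustration, not the PR's code: the helper names `jax_is_64bit` and `warn_if_jax_32bit` are hypothetical, and it assumes the default float dtype (`float64` only when `jax_enable_x64` is on) is an acceptable proxy for the current mode. It only warns; it does not change any JAX setting.

```python
import warnings

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is optional; without it there is nothing to check
    jax = None


def jax_is_64bit():
    """Hypothetical helper: True/False for JAX x64 mode, None if JAX is absent.

    The default float dtype is float64 only when jax_enable_x64 is on.
    """
    if jax is None:
        return None
    return jnp.zeros(()).dtype == "float64"


def warn_if_jax_32bit():
    """Hypothetical import-time check: warn, but leave the config untouched."""
    if jax_is_64bit() is False:
        warnings.warn(
            "JAX is in 32-bit mode; 64-bit mode will be enabled the first "
            "time a JIT-compiled function is called.",
            stacklevel=2,
        )


# Run once when the (hypothetical) module is imported.
warn_if_jax_32bit()
```

Because nothing is toggled here, importing the package no longer silently changes JAX's global precision for other code in the same process.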

Overall, code functionality is unchanged.

The JIT-or-passthrough decorator is now wrapped inside another function that checks the config at call time.
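A minimal sketch of such a call-time wrapper, assuming a decorator that JIT-compiles with `jax.jit` when JAX is installed and otherwise returns the function unchanged. The name `jit_or_passthrough` here is a stand-in; this is not pymbar's actual implementation. The x64 check and the `jax.config.update("jax_enable_x64", True)` call happen on the first call rather than when the decorator is applied at import.

```python
import functools
import warnings

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # no JAX: the decorator becomes a passthrough
    jax = None


def jit_or_passthrough(func):
    """Hypothetical decorator: JIT with JAX if available, else return func as-is.

    The 64-bit switch is deferred to the first call instead of import time.
    """
    if jax is None:
        return func  # passthrough: plain Python/NumPy function

    jitted = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal jitted
        if jitted is None:
            # First call: check the config now, not when the module was imported.
            if jnp.zeros(()).dtype != "float64":
                warnings.warn(
                    "Enabling JAX 64-bit mode (jax_enable_x64=True).",
                    stacklevel=2,
                )
                jax.config.update("jax_enable_x64", True)
            jitted = jax.jit(func)  # compile lazily, after x64 is set
        return jitted(*args, **kwargs)

    return wrapper


@jit_or_passthrough
def add(x, y):
    return x + y
```

Calling `add(1.5, 2.0)` returns a plain float without JAX, or a JAX array (in 64-bit mode) with JAX. Creating the `jax.jit` object lazily ensures tracing happens only after the precision mode has been set.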

codecov[bot] commented 1 year ago

Codecov Report

Merging #504 (dc29e7d) into master (dcc6b6d) will decrease coverage by 0.01%. The diff coverage is 100.00%.