This PR moves the point at which JAX 64-bit mode is enabled from import time to the first call of a JIT'd function. If JAX is found to be in 32-bit mode on import, a warning is issued; a second warning is issued when the mode is actually toggled on that first call.
Otherwise, functionality is unchanged.
The JIT-or-passthrough decorator is now wrapped in an outer function that checks the config when the decorated function is called, rather than when it is defined.
Partial solution to #496