Follow-up to #504 and #496. The initial implementation didn't actually defer enabling JAX's 64-bit mode until the first call. This implementation defers the JIT call until the function is actually used. This shouldn't break JAX's cache: the function object in question never changes, so its hash won't change and we still get all of the accelerated code.
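As a rough illustration of the "defer until first call" idea, here is a minimal sketch. It is not the code in this PR. `lazy_jit` and `enable_x64` are hypothetical helpers. The only real JAX calls assumed are `jax.jit` and `jax.config.update("jax_enable_x64", True)`. Neither the import of `jax` nor the 64-bit setup nor the compilation happens until the wrapped function is first called. After that, the compiled function is reused.

```python
import functools


def lazy_jit(fn, setup=None, compile_fn=None):
    """Hypothetical helper: defer `setup` and JIT compilation of `fn` to first call.

    `setup` runs once, just before compiling (e.g. enabling 64-bit mode).
    `compile_fn` defaults to `jax.jit`; it is injectable mainly for testing.
    Not thread-safe: two concurrent first calls could both compile.
    """
    compiled = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        if compiled is None:
            if setup is not None:
                setup()
            jit = compile_fn
            if jit is None:
                import jax  # imported lazily so importing this module stays cheap

                jit = jax.jit
            # Compile once from the same `fn` object and keep the result.
            compiled = jit(fn)
        return compiled(*args, **kwargs)

    return wrapper


def enable_x64():
    """Hypothetical setup hook: switch JAX to 64-bit mode."""
    import jax

    jax.config.update("jax_enable_x64", True)


# Usage (assumes JAX is installed); nothing JAX-related runs at definition time:
# @functools.partial(lazy_jit, setup=enable_x64)
# def add(x, y):
#     return x + y
#
# add(1.0, 2.0)  # first call: enables x64, then JIT-compiles and runs
```

Running the setup hook inside the first call, rather than at import time, is what keeps the 64-bit switch from taking effect just because the module was imported.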