Operations in jax.numpy appear to be slower than the equivalent NumPy operations unless they are also JIT-compiled or offloaded to an accelerator. When JAX is in use, the code base converts a NumPy array to a JAX array before calling a JIT-compiled function, but does not convert the result back to a NumPy array once the function returns. Any later code that is not JIT-compiled may therefore end up running on JAX arrays. We should investigate whether converting the result back to NumPy gives a speed improvement over the current process.
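One possible starting point for that investigation is a small timing harness that runs both variants on the same input. This is a sketch under stated assumptions: `jitted_step`, `post_process`, `current_process`, `convert_back` and `benchmark` are hypothetical names, and the workload (one JIT-compiled step followed by a few eager array operations) is an assumed stand-in for the real code base. The JAX and NumPy calls used (`jax.jit`, `jnp.asarray`, `np.asarray`) are standard.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def jitted_step(x):
    # Hypothetical stand-in for one of the code base's JIT-compiled functions.
    return jnp.tanh(2.0 * x + 1.0)


def post_process(xp, y):
    # Hypothetical eager (un-jitted) work done after the JIT call returns.
    # `xp` is the array module used for it: numpy or jax.numpy.
    for _ in range(10):
        y = xp.sin(y) * 0.5 + 0.1
    return float(xp.sum(y))


def current_process(x_np):
    # Current behaviour: convert to a JAX array, call the JIT function,
    # and keep working on the returned jax.Array with jax.numpy.
    y = jitted_step(jnp.asarray(x_np))
    return post_process(jnp, y)


def convert_back(x_np):
    # Candidate behaviour: convert the jax.Array result back to NumPy
    # (a device-to-host transfer on an accelerator) and continue with NumPy.
    y = np.asarray(jitted_step(jnp.asarray(x_np)))
    return post_process(np, y)


def benchmark(n=100_000, repeats=20):
    x = np.random.default_rng(0).standard_normal(n)
    # Warm up once so compilation time is excluded from both timings.
    current_process(x)
    convert_back(x)
    t_current = timeit.timeit(lambda: current_process(x), number=repeats)
    t_convert = timeit.timeit(lambda: convert_back(x), number=repeats)
    return t_current / repeats, t_convert / repeats


if __name__ == "__main__":
    t_cur, t_conv = benchmark()
    print(f"keep jax.Array : {t_cur * 1e3:.3f} ms/call")
    print(f"convert back   : {t_conv * 1e3:.3f} ms/call")
```

Both variants should return (nearly) the same value, so any timing difference comes from where the post-processing runs. The outcome will likely depend on array size, the amount of eager work after the JIT call, and whether an accelerator is in use, so it is worth repeating with representative shapes from the code base.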