[JAX] Fix test failure in Trax due to upcoming change to JAX.
An upcoming change to JAX wraps jit decorators around more standard library functions, including division (/). jit-decorated functions raise an error when passed a large Python scalar integer that overflows the range of an int32 or int64. The workaround is to explicitly cast the Python scalar to a specific type, here np.float32.
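The following is a minimal sketch of why the cast helps. It uses only NumPy and does not import JAX, so it does not reproduce the jit error itself. A Python int such as 2**70 cannot be converted to int32 or int64, which is the overflow described above. np.float32 can represent it. The value 2**70 and the helper name `fits_in` are hypothetical and chosen for illustration.

```python
import numpy as np

def fits_in(dtype, value):
    """Return True if `value` converts to `dtype` without raising OverflowError."""
    try:
        np.array(value, dtype=dtype)
        return True
    except OverflowError:
        return False

# Hypothetical large Python scalar: 2**70 exceeds both the int32 and int64 ranges.
big = 2**70

print(fits_in(np.int32, big))    # False
print(fits_in(np.int64, big))    # False
print(fits_in(np.float32, big))  # True

# The workaround: cast the scalar explicitly, so no implicit integer conversion happens.
ratio = np.float32(1.0) / np.float32(big)
print(ratio)
```

In a test, the fix amounts to something like `x / np.float32(big)` instead of `x / big`. The division then receives a float32 operand rather than an oversized Python int.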