google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

[JAX] Fix test failure in Trax due to upcoming change to JAX. #1680

Closed copybara-service[bot] closed 3 years ago

copybara-service[bot] commented 3 years ago

An upcoming change to JAX wraps jit decorators around more standard library functions, including division (/). jit-decorated functions raise an error if passed a Python scalar integer too large to fit in an int32 or int64. The workaround is to explicitly cast the Python scalar to a specific type, here np.float32.
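
A minimal sketch of the kind of cast described above, using NumPy only. It does not reproduce JAX's jit behavior or the actual Trax change; the value `big` and the helper `fits_in_int64` are hypothetical names chosen for illustration. It shows that a Python integer outside the int64 range cannot be converted to a fixed-width integer, while an explicit `np.float32` cast gives a well-typed scalar that is safe to divide by.

```python
import numpy as np

# Hypothetical value for illustration: one past the unsigned 64-bit range,
# so it cannot be represented as an int32 or int64.
big = 2**64


def fits_in_int64(value: int) -> bool:
    """Return True if a Python int is representable as a signed 64-bit integer."""
    info = np.iinfo(np.int64)
    return info.min <= value <= info.max


# Converting the raw Python int to a fixed-width integer fails.
try:
    np.array(big, dtype=np.int64)
    overflowed = False
except OverflowError:
    overflowed = True

# Workaround in the spirit of the PR description: cast the Python scalar
# explicitly to np.float32 before using it in arithmetic.
divisor = np.float32(big)
result = np.float32(1.0) / divisor
```

The cast trades integer exactness for range: float32 covers magnitudes far beyond int64, but only with about 24 bits of precision, so this approach suits values used as scale factors or divisors rather than exact counts.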