jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

check_grads improvements #7742

Open mattjj opened 3 years ago

mattjj commented 3 years ago

Ideas from @blakehechtman:

  1. make the default value of EPS in check_grads depend on dtype, rather than being a global. We were initially confused when we saw zeros in the results, until Blake realized that EPS=1e-4 is smaller than bfloat16 precision. (We have some mechanisms for making tolerances and eps depend on dtype, but usually at the test-case level. Putting it in the test utilities would make it more idiot-proof!)
  2. make the error message better than just showing x and y, i.e. indicate which value is the analytical derivative and which is the numerical estimate
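A minimal sketch of what a dtype-dependent default step might look like. The helper name `default_eps` and the square-root-of-machine-epsilon heuristic are assumptions for illustration, not the actual `check_grads` implementation:

```python
import jax.numpy as jnp

def default_eps(dtype):
    """Hypothetical dtype-aware finite-difference step.

    Instead of a fixed global EPS=1e-4, scale the step with the dtype's
    machine epsilon, so low-precision dtypes like bfloat16 get a step
    that is actually representable.
    """
    # sqrt(machine eps) is a common finite-difference step heuristic.
    return float(jnp.finfo(dtype).eps) ** 0.5

# bfloat16 has eps = 2**-7, so its default step (~0.088) is much larger
# than the global 1e-4 that falls below its precision; float64 gets a
# correspondingly smaller step.
```

With this approach, `check_grads(f, args, order=1, eps=default_eps(x.dtype))`-style behavior would happen automatically rather than per test case.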
nicholasjng commented 3 years ago

Another idea from my side, though unrelated to tolerances: while working on a PR adding auxiliary arguments to some jax.lax solvers, I ran into a problem using check_grads with functions that return objects with integral values (bools, ints). In the _check_dtypes_match portion, which checks type equality between primals and tangents, _dtype(x) == _dtype(y) is False for those values, because their tangents are cast to jnp.float0 (cf. here and here), causing the test to fail.

I eventually took the lazy approach and just tested with floating-point auxiliary values, but this might still be worth addressing for use in other contexts?
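A minimal reproduction of the dtype mismatch described above (the function `f` is illustrative, not the actual solver from the PR):

```python
import jax
import jax.numpy as jnp
from jax import dtypes

def f(x):
    # A function whose output includes an integer-valued component,
    # standing in for an integral auxiliary value.
    return x * 2.0, jnp.int32(3)

primals, tangents = jax.jvp(f, (1.0,), (1.0,))

# JAX represents tangents of integer/bool values with the float0 dtype,
# so a naive _dtype(primal) == _dtype(tangent) comparison fails for the
# integer output even though nothing is wrong with the gradient itself.
print(primals[1].dtype, tangents[1].dtype)
```

Here the second primal has dtype int32 while its tangent has dtype float0, which is exactly the mismatch that trips the _check_dtypes_match check.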