mattjj opened this issue 3 years ago (status: Open)
Another idea from my side, though unrelated to tolerances: while working on a PR adding auxiliary arguments to some `jax.lax` solvers, I ran into a problem using `check_grads` with functions that return values of integral dtype (bool, int). In the `_check_dtypes_match` portion, which checks that primals and tangents have matching types, `_dtype(x) == _dtype(y)` is unfortunately `False` for those outputs, since their tangents are cast to `jnp.float0` (cf. here and here), causing the test to fail.
I eventually took the lazy approach and tested only with floating-point auxiliary values, but this might still be worth addressing for use in other contexts?
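The mismatch can be reproduced without `check_grads`. JAX gives tangents of integer and boolean primals the trivial dtype `jax.dtypes.float0`, so a direct dtype comparison rejects them. Below is a minimal sketch of a relaxed check, assuming the fix is to compare each tangent against the tangent dtype implied by its primal. `expected_tangent_dtype`, `naive_dtypes_match`, and `relaxed_dtypes_match` are hypothetical helpers for illustration, not JAX's internal `_check_dtypes_match`.

```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes


def expected_tangent_dtype(primal_dtype):
    # Inexact (floating or complex) primals keep their own dtype for tangents;
    # integer and boolean primals get JAX's trivial tangent dtype, float0.
    primal_dtype = np.dtype(primal_dtype)
    if jnp.issubdtype(primal_dtype, jnp.inexact):
        return primal_dtype
    return np.dtype(dtypes.float0)


def naive_dtypes_match(primal, tangent):
    # Direct comparison, analogous to `_dtype(x) == _dtype(y)`; this is
    # False for an int/bool primal paired with its float0 tangent.
    return np.dtype(primal.dtype) == np.dtype(tangent.dtype)


def relaxed_dtypes_match(primal, tangent):
    # Hypothetical relaxed check: compare the tangent's dtype with the dtype
    # JAX would assign to a tangent of this primal.
    return np.dtype(tangent.dtype) == expected_tangent_dtype(primal.dtype)


# Usage: an int32 primal paired with a float0 tangent built via np.zeros.
int_primal = np.array([1, 2, 3], dtype=np.int32)
int_tangent = np.zeros(int_primal.shape, dtype=dtypes.float0)
print(naive_dtypes_match(int_primal, int_tangent))    # False
print(relaxed_dtypes_match(int_primal, int_tangent))  # True
```

Floating-point primals are still held to an exact match, so a float32 primal with a float64 tangent would continue to be rejected.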
Ideas from @blakehechtman:
- Make the finite-difference step `EPS` used by `check_grads` depend on dtype, rather than being a global. We were initially confused when we saw zeros in results, until Blake realized that `EPS=1e-4` is smaller than bf16 precision. (We have some mechanisms for making tolerances and eps depend on dtype, but usually at the test case level. Putting it in the test utilities would make it more idiot-proof!)
- Label `x` and `y` in the failure output, i.e. indicate which is analytical and which is numerical.
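Both ideas can be sketched outside of `check_grads`. This assumes a hypothetical `default_eps(dtype)` rule (the cube root of machine epsilon, a common rule of thumb for central differences, not necessarily what JAX would adopt) and a hypothetical `check_jvp_labeled` helper whose failure message names the analytical and numerical sides; only `jax.jvp`, `jnp.finfo`, and NumPy are real APIs here. For bfloat16, machine epsilon is 2^-7 ≈ 0.0078, so a step of 1e-4 is rounded away at x = 1 and the finite difference collapses to zero, which matches the zeros described above.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A single global step size, as described in the issue (EPS=1e-4).
GLOBAL_EPS = 1e-4


def default_eps(dtype):
    # Hypothetical dtype-dependent step: cube root of machine epsilon, a
    # common rule of thumb for central differences. For bfloat16 this is
    # about 0.198; for float32 about 0.0049.
    return float(jnp.finfo(dtype).eps) ** (1.0 / 3.0)


def numerical_jvp(f, x, v, eps):
    # Central finite difference, evaluated in the primal's own dtype.
    eps = jnp.asarray(eps, dtype=x.dtype)
    return (f(x + eps * v) - f(x - eps * v)) / (2 * eps)


def check_jvp_labeled(f, x, v, eps=None, rtol=5e-2, atol=1e-3):
    # Hypothetical checker: compares jax.jvp (analytical) with a finite
    # difference (numerical) and names both sides when they disagree.
    if eps is None:
        eps = default_eps(x.dtype)
    _, analytical = jax.jvp(f, (x,), (v,))
    numerical = numerical_jvp(f, x, v, eps)
    analytical = np.asarray(analytical.astype(jnp.float32))
    numerical = np.asarray(numerical.astype(jnp.float32))
    if not np.allclose(analytical, numerical, rtol=rtol, atol=atol):
        raise AssertionError(
            f"JVP mismatch (dtype={x.dtype}, eps={eps:g}):\n"
            f"  analytical (jax.jvp):          {analytical}\n"
            f"  numerical (finite difference): {numerical}"
        )


x = jnp.array([1.0], dtype=jnp.bfloat16)
v = jnp.ones_like(x)
# With eps=1e-4, x +/- eps rounds back to 1.0 in bfloat16, so the
# finite difference is zero.
print(numerical_jvp(jnp.sin, x, v, GLOBAL_EPS))
check_jvp_labeled(jnp.sin, x, v)  # passes with the dtype-dependent step
```

A larger step trades truncation error for rounding error; the loose default `rtol` here reflects bfloat16's precision and would likely also need to depend on dtype in a real utility.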