Good question! The errors are actually stochastic (re-running normally fixes them) and only occur on macOS (I don't think it has anything to do with the Python version). The error message is about value mismatches, not a `TypeError`, so its source isn't anything to do with the number of outputs. While `grad` is only defined for scalar outputs, `check_grads` is well defined for any number of outputs because it tests the JVP and VJP directly. `check_grads` does use random directions to check the gradients, but it looks like the random seed is fixed (here and here), so it's not clear why we would get stochastic failures.
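To make that concrete, here's a rough sketch (using a made-up two-output function as a stand-in for `kepler`) of why `check_grads` is fine with multiple outputs: it exercises JVPs and VJPs, which are defined for arbitrary output structures, rather than going through `jax.grad`:

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def two_outputs(x):
    # Stand-in for a function that, like kepler, returns two arrays.
    return jnp.sin(x), jnp.cos(x)

x = jnp.array(0.3)

# JVPs and VJPs are defined for any output structure...
primals_out, tangents_out = jax.jvp(two_outputs, (x,), (jnp.ones_like(x),))
_, vjp_fn = jax.vjp(two_outputs, x)

# ...and these are what check_grads compares against finite differences,
# so multiple outputs are fine here even though jax.grad would reject them.
check_grads(two_outputs, (x,), order=2, modes=("fwd", "rev"))
```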
In the short term our best bet might be to label that test with xfail on macOS, and continue testing it on linux.
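Something like this pytest marker would do it (the test name and reason string below are just placeholders):

```python
import sys

import pytest

# Tolerate the flaky failure only on macOS; keep enforcing the test on Linux.
@pytest.mark.xfail(
    sys.platform == "darwin",
    reason="stochastic check_grads value mismatches on macOS",
    strict=False,
)
def test_grad():
    ...
```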
Ah I see! I did see in the output that the error was caused by a `ValueError` rather than a `TypeError`, so that was another point I was confused about. I reran the test earlier and it failed again (I see that it passes now), so I thought it might have been something non-stochastic.

Thanks! I'll close this for now.
The checks for PR #156 are failing due to some of the `test_grad` tests failing. It's a bit odd because if I run the tests locally they all pass, and it only seems to be an issue with the Python 3.11 tests. I'm not entirely sure about the root cause of the failures, but I think one potential issue is that the `kepler` function actually returns 2 values. Since JAX gradients are only defined for scalar-output functions, a call like the one sketched at the end of this issue causes an error. I'm not sure why `check_grads` is actually passing for certain cases given that `jax.grad(kepler)` shouldn't ever work?

@dfm: Should we return a single value from `kepler`?
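For concreteness, this is the kind of call I mean (the function below is an illustrative two-output stand-in, not the real `kepler` signature):

```python
import jax
import jax.numpy as jnp

def kepler_like(m):
    # Stand-in for kepler: returns two outputs rather than a single scalar.
    return jnp.sin(m), jnp.cos(m)

# jax.grad only supports scalar-output functions, so this call raises a
# TypeError complaining about the non-scalar output.
jax.grad(kepler_like)(jnp.array(0.3))
```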