exoplanet-dev / jaxoplanet

Astronomical time series analysis with JAX
https://jax.exoplanet.codes
MIT License
`core.kepler` should return a single value if we're checking the gradients #158

Closed: soichiro-hattori closed 9 months ago

soichiro-hattori commented 9 months ago

The checks for PR #156 are failing because some of the test_grad tests fail. It's a bit odd: when I run the tests locally they all pass, and the failures only seem to affect the Python 3.11 tests.

I'm not entirely sure about the root cause, but I think one potential issue is that the kepler function actually returns two values. Since JAX gradients are only defined for scalar-output functions, the following causes an error:

jax.grad(kepler)(m, e)  # for some chosen values for the mean anomaly and eccentricity.
TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.96160406, dtype=float32), Array(0.27444044, dtype=float32))

I'm not sure why check_grads actually passes in some cases, given that jax.grad(kepler) should never work.
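
For reference, the same TypeError can be reproduced with a toy function. This is only a minimal sketch: `two_outputs` is a hypothetical stand-in that returns a tuple, not jaxoplanet's actual `kepler`, and the input values are arbitrary. It also shows two ways one might still differentiate such a function, by picking a single output or by using `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp


def two_outputs(m, e):
    # Hypothetical stand-in for a function that returns a tuple of two
    # scalars; this is NOT the jaxoplanet kepler implementation.
    x = m + e * jnp.sin(m)
    return jnp.sin(x), jnp.cos(x)


# jax.grad rejects the tuple output with a TypeError.
try:
    jax.grad(two_outputs)(0.5, 0.1)
    message = ""
except TypeError as err:
    message = str(err)

# Option 1: differentiate one output at a time (here with respect to m).
d_sin_dm = jax.grad(lambda m, e: two_outputs(m, e)[0])(0.5, 0.1)

# Option 2: jacfwd accepts tuple outputs and returns a matching tuple of
# derivatives with respect to the first argument.
jac = jax.jacfwd(two_outputs)(0.5, 0.1)
```

Both options leave the function's return signature unchanged, so neither requires the function itself to return a single value.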

@dfm: Should we return a single value from kepler?

dfm commented 9 months ago

Good question! The errors are actually stochastic (re-running normally fixes them) and only occur on macOS (I don't think it has anything to do with the Python version). The error message is about value mismatches, not a TypeError, so its source has nothing to do with the number of outputs. While grad is only defined for scalar outputs, check_grads is well defined for any number of outputs because it tests the JVP and VJP directly.
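
To illustrate the point about multiple outputs, here is a small sketch using a toy function. `two_outputs` is a hypothetical placeholder rather than the real `kepler`, and the inputs are arbitrary. `check_grads`, `jax.jvp`, and `jax.vjp` all accept a tuple-valued function.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads


def two_outputs(m, e):
    # Hypothetical tuple-valued function; NOT the jaxoplanet kepler code.
    x = m + e * jnp.sin(m)
    return jnp.sin(x), jnp.cos(x)


primals = (jnp.asarray(0.5), jnp.asarray(0.1))

# check_grads compares JVPs and VJPs against numerical differences, so a
# tuple output is fine. It raises an AssertionError on a mismatch.
check_grads(two_outputs, primals, order=1, modes=("fwd", "rev"))

# Forward mode: push a tangent in the m direction through both outputs.
tangents = (jnp.ones_like(primals[0]), jnp.zeros_like(primals[1]))
out, out_tangent = jax.jvp(two_outputs, primals, tangents)

# Reverse mode: pull back a cotangent that selects only the first output.
_, vjp_fn = jax.vjp(two_outputs, *primals)
cot_m, cot_e = vjp_fn((jnp.ones_like(out[0]), jnp.zeros_like(out[1])))
```

In this sketch the forward-mode tangent of the first output and the reverse-mode cotangent for `m` are the same derivative computed two ways, which is the kind of consistency check_grads verifies.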

check_grads does use random directions to check the gradients, but it looks like the random seed is fixed (here and here) so it's not clear why we would get stochastic failures.

In the short term our best bet might be to mark that test as xfail on macOS and continue testing it on Linux.
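
A platform-conditional xfail could look roughly like the sketch below. The test name and body are hypothetical placeholders, not the actual jaxoplanet test, and `strict=False` is an assumption chosen so that an unexpected pass on macOS is not reported as a failure.

```python
import sys

import jax.numpy as jnp
import pytest
from jax.test_util import check_grads


@pytest.mark.xfail(
    sys.platform == "darwin",  # the condition is only true on macOS
    reason="Stochastic gradient-check failures seen on macOS",
    strict=False,  # assumption: tolerate passes, since failures are intermittent
)
def test_grad_toy():
    # Hypothetical placeholder body; the real test would call the function
    # whose gradients are being checked.
    check_grads(jnp.sin, (jnp.asarray(0.5),), order=1)
```

On Linux the condition is false, so the marker has no effect and the test runs and reports as usual.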

soichiro-hattori commented 9 months ago

Ah, I see! I did notice in the output that the error was caused by a ValueError rather than a TypeError, which was another point I was confused about.

I reran the test earlier and it failed again (I see that it passes now), so I thought it might have been something non-stochastic.

Thanks! I'll close this for now.