exoplanet-dev / jaxoplanet

Astronomical time series analysis with JAX
https://jax.exoplanet.codes
MIT License

feat: rotation matrices using s2fft #212

Closed: lgrcia closed this 4 weeks ago

lgrcia commented 2 months ago

Using the s2fft Python package to compute the Wigner D-matrices used to rotate the spherical harmonics. See also #140
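Some background on the "Wigner D-matrices" mentioned above. Rotations never mix spherical harmonic degrees, so each degree l gets its own (2l+1)×(2l+1) matrix. For complex spherical harmonics, the Euler-angle dependence factors into phases in alpha and gamma around a real "small-d" matrix d^l(beta). The sketch below computes d^l(beta) in pure Python from Wigner's explicit sum formula. It is only an illustration and is not how s2fft computes these matrices. It is suitable only for small l, and sign and index conventions differ between references (in this one, d^1_{00}(beta) = cos(beta)). The names `wigner_small_d` and `matmul` are hypothetical.

```python
import math


def wigner_small_d(l, beta):
    """(2l+1) x (2l+1) Wigner small-d matrix d^l(beta), as nested lists.

    Row index = m' + l, column index = m + l, with m', m in -l..l.
    Uses Wigner's explicit sum formula: fine for small l, but it loses
    precision as l grows (large factorials with alternating signs).
    """
    c, s = math.cos(beta / 2), math.sin(beta / 2)
    f = math.factorial
    size = 2 * l + 1
    d = [[0.0] * size for _ in range(size)]
    for mp in range(-l, l + 1):
        for m in range(-l, l + 1):
            pref = math.sqrt(f(l + mp) * f(l - mp) * f(l + m) * f(l - m))
            total = 0.0
            # k runs over every value that keeps all factorial arguments >= 0
            for k in range(max(0, m - mp), min(l + m, l - mp) + 1):
                sign = (-1) ** (mp - m + k)
                den = f(l + m - k) * f(k) * f(mp - m + k) * f(l - mp - k)
                total += sign / den * c ** (2 * l + m - mp - 2 * k) * s ** (mp - m + 2 * k)
            d[mp + l][m + l] = pref * total
    return d


def matmul(a, b):
    # Plain nested-list matrix product, used only for sanity checks
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


# d^1_{00}(beta) should equal cos(beta) in this convention
print(wigner_small_d(1, 0.7)[1][1], math.cos(0.7))
```

Two quick sanity checks for any implementation are orthogonality (d^l(beta) times its transpose is the identity) and composition about a single axis (d^l(b1) d^l(b2) = d^l(b1 + b2)).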

lgrcia commented 2 months ago

To work around a NumPy import issue in s2fft (see s2fft#206), I had to pin numpy<2.0. Not sure we want to stick with that.

lgrcia commented 1 month ago

I can't figure out why the macos-python3.11 test hangs until it times out... Any help would be welcome!

dfm commented 1 month ago

> I can't figure out why the macos-python3.11 test hangs until it times out

Strange! Could you try removing `-n auto` from the pytest command? Does it work OK locally on your Mac?

dfm commented 1 month ago

I think maybe GitHub Actions was actually just having issues that day. I've tried re-running the job. Let's see if that works!

lgrcia commented 1 month ago

The re-run failed too. I can reproduce the problem locally, but I can't pin down which test it comes from. Individual tests pass locally but freeze when run together. Could it be a memory issue? I'm stumped...

The major change in this PR is pinning numpy<2.0.

dfm commented 1 month ago

@lgrcia: I tried turning off parallel execution of the tests, and that changed the behavior: now it just hangs indefinitely when running the `test_light_curves_orders` test (I think). This isn't failing on the main branch, so it must have something to do with s2fft. I can't imagine the issue has to do with NumPy (although pinning it to <2.0 seems like a Bad Idea™; hopefully they can fix the issues soon!). It might be worth trying some different versions of JAX (e.g. <=0.4.31) and s2fft to see if you can narrow it down to a combination that works.

dfm commented 1 month ago

The issue here must be related to s2fft! @lgrcia, can you try to work out which inputs we're passing in the test that hangs to isolate exactly what s2fft call we're executing. It would be interesting to know if we can get the same hang using just s2fft and no jaxoplanet. In that case we can report upstream and see what they say.

lgrcia commented 1 month ago

I think I've identified where the problem comes from.

Here is a way to reproduce it on macOS:

```python
import numpy as np
from jaxoplanet.experimental.starry.rotation import dot_rotation_matrix

l_max = 5
theta = 0
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)
calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident)
```

This runs OK. But then, when `l_max` is changed in the same runtime:

```python
l_max = 6
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)
```

it freezes. I think it has to do with how I (or s2fft) combine the static arguments in the jitted functions from `s2fft.utils.rotation` and `s2fft.utils` (see also the jitted signature of `jaxoplanet.experimental.starry.s2fft_rotation.py:compute_rotation_matrices`). But I honestly don't fully understand why it is only a problem on macOS.
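The identity size used in these snippets, `l_max**2 + 2 * l_max + 1`, equals `(l_max + 1)**2`. That is the number of spherical harmonic coefficients up to degree `l_max`, because degree l contributes 2l + 1 of them. A rotation never mixes degrees, so it acts block-diagonally on this vector with one (2l+1)-square block per degree. The helper below is hypothetical (not a jaxoplanet function) and assumes coefficients are ordered by l, then m = -l..l. It lists where each block sits, which can help when checking a rotated identity block by block.

```python
def degree_blocks(l_max):
    """Return (l, start, stop) for each degree-l block of a flat
    spherical-harmonic coefficient vector ordered by l, then m = -l..l."""
    blocks = []
    start = 0
    for l in range(l_max + 1):
        stop = start + 2 * l + 1  # degree l holds 2l + 1 coefficients
        blocks.append((l, start, stop))
        start = stop
    return blocks


print(degree_blocks(2))  # [(0, 0, 1), (1, 1, 4), (2, 4, 9)]

# The last block ends at (l_max + 1)**2, matching the identity size above
for l_max in (5, 6):
    assert degree_blocks(l_max)[-1][2] == l_max**2 + 2 * l_max + 1 == (l_max + 1) ** 2
```

With this ordering, each block starts at index l**2.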

I don't really understand why it behaves like this, but a workaround that works for me is to avoid decorating the s2fft rotation functions with `jit`. So, for now, I copied all the required functions (we only need 100 lines of Python from s2fft) and removed the s2fft dependency.
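The following is a minimal sketch of the workaround pattern described above. It assumes the copied helpers only need `l_max` as a Python int. The idea is to keep a single `jax.jit` boundary at the outermost function, mark only `l_max` as static, and leave the helpers undecorated so JAX simply traces through them. `_z_phases` and `z_phases` are hypothetical toys (cosines of m·beta per degree), not the code copied from s2fft.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _z_phases(l_max, beta):
    # Plain helper, intentionally NOT decorated with jax.jit:
    # l_max is a Python int used as a loop bound, beta may be a tracer.
    phases = []
    for l in range(l_max + 1):
        m = jnp.arange(-l, l + 1)
        phases.append(jnp.cos(m * beta))
    return jnp.concatenate(phases)  # length (l_max + 1)**2


@partial(jax.jit, static_argnums=(0,))
def z_phases(l_max, beta):
    # The only jit boundary: l_max is static, beta is traced.
    return _z_phases(l_max, beta)


out5 = z_phases(5, 0.3)  # compiles for l_max = 5
out6 = z_phases(6, 0.3)  # a new static l_max compiles a separate version

# beta stays a traced value, so we can differentiate with respect to it
dsum_dbeta = jax.grad(lambda b: z_phases(5, b).sum())(0.3)
```

Because beta is traced rather than static, one compiled function per `l_max` serves every angle, and gradients with respect to the angle work.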

I'd like to understand the problem better before reintroducing s2fft as a dependency.

lgrcia commented 1 month ago

> The issue here must be related to s2fft! @lgrcia, can you try to work out which inputs we're passing in the test that hangs to isolate exactly what s2fft call we're executing. It would be interesting to know if we can get the same hang using just s2fft and no jaxoplanet. In that case we can report upstream and see what they say.

@dfm, here is a way to reproduce it with s2fft alone:

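```python
from functools import partial

import jax
from s2fft.utils.rotation import generate_rotate_dls


# deg is static for this outer jit; alpha is traced inside f.
# s2fft's generate_rotate_dls is itself jitted, with its beta argument
# marked static (see the link below).
@partial(jax.jit, static_argnums=(0,))
def f(deg, alpha):
    return generate_rotate_dls(deg, alpha)


_ = f(5, 0.0)  # this executes fine
_ = f(10, 0.0)  # this freezes (on macOS)
```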

I might open an issue, but I think this isn't a proper use of this function given its static arguments (see https://github.com/astro-informatics/s2fft/blob/main/s2fft/utils/rotation.py#L75). Any idea why this would happen?

As I understand it, each test should pass when run in its own Python instance. So could the issue be due to how the pytest runners are dispatched on macOS?

dfm commented 1 month ago

It's fascinating to me that this happens and that your solution works! I don't see any reason why the beta parameter should be labelled as static, but it also seems like nesting these `static_argnums` incompatibly like this should crash rather than hang...
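One concrete cost of marking a continuous angle such as beta static is that `jax.jit` caches compiled code per distinct value of each static argument. A static float angle therefore triggers a fresh trace (and compile) for every new value. The sketch below counts traces with a Python side effect, which only runs while JAX is tracing. Both functions are hypothetical toy stand-ins, not s2fft's `generate_rotate_dls`.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_counts = {"static": 0, "traced": 0}


@partial(jax.jit, static_argnums=(0, 1))
def dl_static_beta(l, beta):
    trace_counts["static"] += 1  # Python side effect: runs only while tracing
    return jnp.cos(jnp.arange(-l, l + 1) * beta)


@partial(jax.jit, static_argnums=(0,))
def dl_traced_beta(l, beta):
    trace_counts["traced"] += 1  # Python side effect: runs only while tracing
    return jnp.cos(jnp.arange(-l, l + 1) * beta)


for beta in (0.1, 0.2, 0.3):
    dl_static_beta(5, beta)  # each new static beta value means a new trace
    dl_traced_beta(5, beta)  # same shape and dtype: the cached trace is reused

print(trace_counts)  # {'static': 3, 'traced': 1}
```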

Regardless: I think this is a good "fix"!

lgrcia commented 4 weeks ago

@dfm, do you think we are ready to merge?

lgrcia commented 4 weeks ago

Just for reference, the s2fft addition was merged by mistake in #225. Thanks for all the reviews on this!!