astro-informatics / s2wav

Differentiable and accelerated wavelet transform on the sphere with JAX
https://astro-informatics.github.io/s2wav/
MIT License

s2wav.analysis fails for Healpix map input with AssertionError #84

Open 1cosmologist opened 1 month ago

1cosmologist commented 1 month ago

I am trying to compute the directional wavelet transform of a HEALPix map. I have tried both s2wav.analysis and s2wav.wavelet.flm_to_analysis (with the map-to-flm step computed separately with s2fft), and I encounter an AssertionError in both cases.

Minimal example (with numpy imported as np and s2wav as sw):

```python
nside = 128
lmax = 2 * nside
N = 3

hpx_map = np.ones((12*nside**2,))

filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')
```

Fails with the following error message:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[116], line 8
      5 hpx_map = np.ones((12*nside**2,))
      7 filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
----> 8 wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2wav/transforms/wavelet.py:189, in analysis(f, L, N, J_min, lam, spin, sampling, nside, reality, filters, precomps)
    174     Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
    175     f_wav_lmn[j - J_min] = (
    176         f_wav_lmn[j - J_min]
    177         .at[::2, L0j:]
   (...)
    185         )
    186     )
    188     f_wav.append(
--> 189         s2fft.wigner.inverse_jax(
    190             f_wav_lmn[j - J_min],
    191             Lj,
    192             Nj,
    193             nside,
    194             sampling,
    195             reality,
    196             precomps[j - J_min],
    197             L0j,
    198         )
    199     )
    201 # Project all harmonic coefficients for each lm onto scaling coefficients
    202 phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1))

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:257, in inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
    251     precomps = [p0, p1, p2, p3, p4]
    252     return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
    253         flm, L, -spin, nside, sampling, False, precomps, False, L_lower
    254     )
    256 fban = fban.at[N - 1 + n_start_ind :].set(
--> 257     vmap(
    258         partial(func, p2=precomps[2][0], p3=precomps[3][0], p4=precomps[4][0]),
    259         in_axes=(0, 0, 0, 0),
    260     )(flmn[N - 1 + n_start_ind :], spins, precomps[0], precomps[1])
    261 )
    262 if reality:
    263     f = jnp.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=0, norm="forward")

    [... skipping hidden 3 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:252, in inverse_jax.<locals>.func(flm, spin, p0, p1, p2, p3, p4)
    250 def func(flm, spin, p0, p1, p2, p3, p4):
    251     precomps = [p0, p1, p2, p3, p4]
--> 252     return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
    253         flm, L, -spin, nside, sampling, False, precomps, False, L_lower
    254     )

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/spherical.py:319, in inverse_jax(flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower)
    315     ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
    316         jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
    317     )
    318 if sampling.lower() == "healpix":
--> 319     return hp.healpix_ifft(ftm, L, nside, "jax")
    320 else:
    321     ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1))

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/utils/healpix_ffts.py:398, in healpix_ifft(ftm, L, nside, method, reality)
    368 def healpix_ifft(
    369     ftm: np.ndarray,
    370     L: int,
   (...)
    373     reality: bool = False,
    374 ) -> np.ndarray:
    375     """Wrapper function for the Inverse Fast Fourier Transform with spectral folding
    376     in the polar regions to mitigate aliasing.
    377 
   (...)
    396         np.ndarray: HEALPix pixel-space array.
    397     """
--> 398     assert L >= 2 * nside
    399     if method.lower() == \"numpy\":
    400         return healpix_ifft_numpy(ftm, L, nside, reality)

AssertionError:

MichaelJacob914 commented 3 weeks ago

I am also having this issue. I did not have it before the April commits, and I have not altered my code since.

1cosmologist commented 3 weeks ago

It doesn't seem to be caused by the April commits. I reinstalled from source after resetting to the 8 March commit (4baa256) and switched s2fft to v1.0.2, and I still get the same AssertionError.

jasonmcewen commented 3 weeks ago

I'm not sure whether a Healpix map is supported as an input argument here. You might need to compute the alms from the Healpix map and pass that in to the appropriate function. @CosmoMatt, can you comment?

1cosmologist commented 3 weeks ago

@jasonmcewen I have tried both. The analysis function has a sampling option that accepts 'healpix', and I have also tried passing alms as input via s2wav.wavelet.flm_to_analysis. In both cases I encounter the AssertionError.

jasonmcewen commented 3 weeks ago

Ok, hopefully @CosmoMatt will be able to help shortly...

CosmoMatt commented 3 weeks ago

Ah OK, I see the issue @jasonmcewen @1cosmologist. We defaulted the wavelet transform to the multi-resolution algorithm (see section 3.1 of this paper), in which each wavelet scale j is computed at its own, lower band-limit L_j. The underlying Wigner and harmonic transforms (from the s2fft package) are not configured to handle this for HEALPix sampling: s2fft's HEALPix FFT asserts L >= 2 * nside, and the band-limits of the coarser scales are smaller than that, hence the error being thrown.
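To see why most scales trip that assertion, here is a small pure-Python sketch. It assumes (this is not taken from the s2wav source) that scale j is band-limited at L_j = min(ceil(lam^(j+1)), L) with J_max = ceil(log_lam(L)); the helpers scale_bandlimits and failing_scales are hypothetical names used only for illustration. It lists the scales whose L_j would fail s2fft's HEALPix check L >= 2 * nside.

```python
import math


def scale_bandlimits(L: int, lam: float = 2.0, J_min: int = 0) -> dict:
    """Hypothetical per-scale band-limits of a multiresolution wavelet transform.

    Assumption (not from the s2wav source): scale j is band-limited at
    L_j = min(ceil(lam ** (j + 1)), L), and J_max = ceil(log_lam(L)).
    Requires lam > 1.
    """
    # Smallest J_max with lam ** J_max >= L, computed without floating-point logs.
    J_max = 0
    while lam ** J_max < L:
        J_max += 1
    return {j: min(math.ceil(lam ** (j + 1)), L) for j in range(J_min, J_max + 1)}


def failing_scales(L: int, nside: int, lam: float = 2.0, J_min: int = 0) -> list:
    """Scales whose band-limit would violate the HEALPix check L_j >= 2 * nside."""
    return [j for j, L_j in scale_bandlimits(L, lam, J_min).items() if L_j < 2 * nside]


# Configuration from the original report: nside = 128, L = 2 * nside.
print(scale_bandlimits(256))              # {0: 2, 1: 4, 2: 8, 3: 16, 4: 32, 5: 64, 6: 128, 7: 256, 8: 256}
print(failing_scales(256, nside=128))     # [0, 1, 2, 3, 4, 5, 6]
```

Under these assumptions only the two finest scales meet the HEALPix requirement for nside = 128, so the check already fails on the first scale that is transformed.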

We could add support for this but (and it's a big but) the wavelet transform for HEALPix sampling would be very inaccurate, because (a) we don't yet support iterations (iterative refinement of the forward transform), without which the HEALPix SHT is extremely inaccurate, and (b) this error gets compounded in the Wigner transform.
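What "iterations" buys can be illustrated with a toy linear-algebra stand-in (this is not the HEALPix SHT itself). When the forward transform is only an approximate inverse of the synthesis, a Jacobi-style iteration re-analyses the pixel-space residual and adds the correction; this is the kind of refinement controlled by the iter argument of healpy.map2alm. The random matrix Y and the helper iterative_analysis are assumptions made for illustration, standing in for an inexact quadrature on the sphere.

```python
import numpy as np


def iterative_analysis(f, Y, n_iter):
    """Toy Jacobi-style refinement of an approximate forward transform.

    Y maps coefficients to pixels (the "synthesis"); Y.T acts as an
    approximate, not exact, inverse (the "analysis").
    """
    a = Y.T @ f  # zeroth-order estimate: no iterations
    for _ in range(n_iter):
        # Analyse the pixel-space residual and add the correction.
        a = a + Y.T @ (f - Y @ a)
    return a


rng = np.random.default_rng(0)
n_pix, n_coeff = 2000, 20
# Columns are only approximately orthonormal, mimicking an inexact quadrature.
Y = rng.standard_normal((n_pix, n_coeff)) / np.sqrt(n_pix)
a_true = rng.standard_normal(n_coeff)
f = Y @ a_true  # "map" synthesised from known coefficients

errors = {n: np.linalg.norm(iterative_analysis(f, Y, n) - a_true) for n in (0, 1, 3)}
for n, err in errors.items():
    print(f"iterations={n}: coefficient error {err:.2e}")
```

Each iteration multiplies the coefficient error by (I - Y^T Y), so the error shrinks geometrically as long as that operator has norm below one; with no iterations, the quadrature error is left in the coefficients.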

@1cosmologist if you aren't married to HEALPix sampling, and would still like to pick up our JAX wavelets, you can do this very straightforwardly by converting between HEALPix and any other sampling in harmonic space. For example, to convert to MW sampling you could run the following:

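```python
import healpy
import numpy as np
import s2fft

nside = 128
L = 2 * nside
f_hp = np.ones(12 * nside**2)
flm_hp = healpy.map2alm(f_hp, lmax=L - 1)  # band-limit L corresponds to lmax = L - 1
flm_mw = s2fft.sampling.s2_samples.flm_hp_to_2d(flm_hp, L)
f_mw = s2fft.inverse(flm_mw, L)  # MW is the default sampling
```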

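Then all the functionality should be supported, and all subsequent transforms on the MW grid should be exact to machine precision (the initial healpy.map2alm step is still approximate). You are also entirely free to convert back to HEALPix at the end of your analysis by converting the coefficients back to healpy ordering with

```python
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm_mw, L)
```

and then synthesising the map with, e.g., healpy.alm2map(flm_hp, nside).
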
CosmoMatt commented 3 weeks ago

@jasonmcewen at the very least we should update the docstrings to indicate that we support MW, MWSS, DH, GL but not HEALPix at the moment. I can make a PR for this when I get a chance.

jasonmcewen commented 3 weeks ago

Thanks very much for the comments @CosmoMatt !

We should certainly support input data in HEALPix format but, indeed, I'm not sure we should use HEALPix internally for the map representation, since accuracy would drop considerably, as you say.

I think @1cosmologist may have run into the issue that the alm interface didn't work due to the different alm storage formats?

So, for now, perhaps we should simply include a demo notebook showing how to run on a HEALPix map? That is, compute alms either with s2fft or healpy (if using healpy then convert the alm format), then pass the alms to the wavelet transform. Could you add a quick notebook @CosmoMatt when you get a chance?
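For reference, the storage difference can be sketched in NumPy. This assumes healpy's m-major ordering with lmax = L - 1 (the index formula matches healpy.Alm.getidx) and a 2D layout of shape (L, 2L - 1) indexed as flm[el, L - 1 + m], which is my reading of the s2fft convention. alm_hp_to_2d and alm_2d_to_hp are hypothetical helpers for illustration only; in practice s2fft.sampling.s2_samples.flm_hp_to_2d and flm_2d_to_hp should be used.

```python
import numpy as np


def hp_index(L: int, el: int, m: int) -> int:
    """Position of (el, m), m >= 0, in healpy-ordered alms with lmax = L - 1.

    healpy stores coefficients m-major: all el for m = 0, then m = 1, and so on
    (same formula as healpy.Alm.getidx(L - 1, el, m)).
    """
    return m * (2 * L - 1 - m) // 2 + el


def alm_hp_to_2d(flm_hp: np.ndarray, L: int) -> np.ndarray:
    """Hypothetical helper: 1D healpy alms -> 2D array flm[el, L - 1 + m].

    healpy keeps only m >= 0 for real maps, so negative m are filled from
    the reality condition f_{el,-m} = (-1)^m conj(f_{el,m}).
    """
    flm_2d = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for el in range(L):
        for m in range(el + 1):
            val = flm_hp[hp_index(L, el, m)]
            flm_2d[el, L - 1 + m] = val
            if m > 0:
                flm_2d[el, L - 1 - m] = (-1) ** m * np.conj(val)
    return flm_2d


def alm_2d_to_hp(flm_2d: np.ndarray, L: int) -> np.ndarray:
    """Hypothetical helper: 2D array -> 1D healpy alms (keeps m >= 0 only)."""
    flm_hp = np.zeros(L * (L + 1) // 2, dtype=np.complex128)
    for el in range(L):
        for m in range(el + 1):
            flm_hp[hp_index(L, el, m)] = flm_2d[el, L - 1 + m]
    return flm_hp


L = 4
n_hp = L * (L + 1) // 2  # 10 coefficients for lmax = 3
rng = np.random.default_rng(1)
flm_hp = rng.standard_normal(n_hp) + 1j * rng.standard_normal(n_hp)
flm_hp[:L] = flm_hp[:L].real  # the m = 0 block (indices 0..L-1) is real for a real map

flm_2d = alm_hp_to_2d(flm_hp, L)
print(flm_2d.shape)                                   # (4, 7)
print(np.allclose(alm_2d_to_hp(flm_2d, L), flm_hp))   # True
```

Passing a 1D healpy array where a 2D (L, 2L - 1) array is expected, or vice versa, is the kind of mismatch these conversions avoid.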

1cosmologist commented 3 weeks ago

@jasonmcewen @CosmoMatt I appreciate that there is a simple way to change sampling schemes for HEALPix maps. I will try it and report back. But I am surprised that s2wav.wavelet.flm_to_analysis did not work when I computed the alms with s2fft. I checked that the alm array produced was 2D, not 1D as it would be for healpy.

Here is what I did to compute the alms for the IRAS 100 micron dust map (at NSIDE=128):

```python
iris_lm = s2fft.transforms.spherical.forward(iris, lmax, spin=0, nside=nside, sampling='healpix', method='jax', reality=True, L_lower=0)
```

CosmoMatt commented 3 weeks ago

Hey @1cosmologist, just to be clear: you're generating iris_lm and then running flm_to_analysis on these coefficients with sampling != "healpix"? If you choose sampling from ["mw", "mwss", "gl", "dh"], flm_to_analysis should work, I think. In any case, this functionality was added specifically for certain applications collaborators were working on, and really shouldn't have been exposed to users, as it's fairly bespoke. For general use I would recommend the main transforms where possible.

We are looking to remove support across the board for HEALPix here, as it no longer operates correctly with required packages.

1cosmologist commented 3 weeks ago

Hi @CosmoMatt, I was running flm_to_analysis with sampling = 'healpix'. That may be the issue, although I am not sure what sampling would do here, since the input is not a HEALPix map. In any case, I think the cleanest approach is the one you described above.

I hope there are plans to keep supporting the conversion from healpy alms to s2fft 2D alms. I think that is all that should be needed: one can write a simple wrapper function to do it all, for both the forward and backward transforms.