astro-informatics / s2wav

Differentiable and accelerated wavelet transform on the sphere with JAX
https://astro-informatics.github.io/s2wav/
MIT License

Wavelets in real space #79

Open: MichaelJacob914 opened this issue 3 months ago

MichaelJacob914 commented 3 months ago

Hi, I was just wondering whether there is a built-in way to transform the harmonic coefficients of the wavelets back into real space, so that the wavelets can be visualized on the sphere as depicted in the README.

jasonmcewen commented 2 months ago

Thanks for the comment, @MichaelJacob914! I'm not sure we have that built in at present, but it would be useful, so we should certainly add it.

(Tagging @CosmoMatt.)

CosmoMatt commented 2 months ago

Hi @MichaelJacob914, as @jasonmcewen says, we haven't included this in the package, to keep down the number of additional requirements users need to install. This is something we may well add very soon, perhaps just in the notebook directory rather than directly in the package.

In any case, here is a code snippet that should produce the kind of images you're looking for:

"""
Requirements:
    - s2fft
    - s2wav
    - numpy
    - mayavi
"""

import numpy as np
from mayavi import mlab
import s2wav
import s2fft
from s2fft.sampling import s2_samples as samples

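def plot_sphere(f: np.ndarray, L: int, sr: float, mx: float, mn: float):
    """Render a real field f, sampled on the MW grid, as a deformed sphere.

    f is an (L, 2L - 1) array of real values (one row per theta ring),
    sr is the base radius of the sphere, and mx, mn are the maximum and
    minimum of f, used to scale the radial displacement and colour range.
    """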

    # Define meshgrid points on spherical surface
    phis = samples.phis_equiang(L, sampling="mw")
    thetas = samples.thetas(L, sampling="mw")

    # Fix continuity at boundaries for visualisation
    thetas[0] = 0
    phis[-1] = 2 * np.pi

    # Generate angular meshgrid
    phi, theta = np.meshgrid(phis, thetas)

    # Shift f so its minimum maps to zero, scale by mx, and add to the base radius sr for visualisation
    temp = (f - mn) / mx
    r = sr + temp

    # Convert angular meshgrid to cartesian
    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)

    # 3D render using mayavi package.
    mlab.figure(1, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(300, 300))
    mlab.clf()
    # temp spans [0, 1 - mn/mx], so pin the colour range to that interval
    mlab.mesh(x, y, z, scalars=temp, colormap="viridis", vmax=1 - mn / mx, vmin=0)
    mlab.show()

if __name__ == "__main__":
    L = 128   # Harmonic bandlimit
    N = 3     # Azimuthal bandlimit
    sr = 5    # Base radius of the sphere; the scaled wavelet is added on top of this for visualisation.

    # Generate the wavelet kernel in spherical harmonic space
    # Note: this generates the directional kernel, but the actual directionality is  
    #       introduced during the convolution. So this is just a single set of harmonic
    #       coefficients which you can imagine being rotated during the transform.
    J_max = s2wav.samples.j_max(L)
    wav_lmn, scal_lm = s2wav.filters.filters_directional_vectorised(L, N)

    # Inverse spherical harmonic transform and plot the wavelet filters on the sphere.
    for j in range(J_max):
        wavelet_harmonic_coeffs = wav_lmn[j]
        wavelet_coeffs = s2fft.inverse(wavelet_harmonic_coeffs, L)

        mx, mn = np.nanmax(np.real(wavelet_coeffs)), np.nanmin(np.real(wavelet_coeffs))
        plot_sphere(np.real(wavelet_coeffs), L, sr, mx, mn)
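To see why `plot_sphere` overwrites `thetas[0]` and `phis[-1]`, it can help to write out the MW sample positions by hand. The sketch below is a minimal, numpy-only illustration that assumes the standard McEwen-Wiaux formulas, theta_t = pi(2t + 1)/(2L - 1) for t = 0, ..., L - 1 and phi_p = 2 pi p/(2L - 1) for p = 0, ..., 2L - 2. `mw_angles` and `close_grid_for_plotting` are hypothetical helpers, not part of s2fft or s2wav, so compare them against `samples.thetas` and `samples.phis_equiang` before relying on them.

```python
import numpy as np


def mw_angles(L: int) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: MW sample positions, assuming the standard formulas.

    theta_t = pi * (2t + 1) / (2L - 1),  t = 0, ..., L - 1
    phi_p   = 2 * pi * p / (2L - 1),     p = 0, ..., 2L - 2
    """
    t = np.arange(L)
    p = np.arange(2 * L - 1)
    thetas = np.pi * (2 * t + 1) / (2 * L - 1)
    phis = 2 * np.pi * p / (2 * L - 1)
    return thetas, phis


def close_grid_for_plotting(thetas: np.ndarray, phis: np.ndarray):
    """Hypothetical helper: the same 'continuity fix' as plot_sphere, on copies."""
    thetas, phis = thetas.copy(), phis.copy()
    thetas[0] = 0.0       # move the first ring onto the north pole
    phis[-1] = 2 * np.pi  # move the last meridian onto 2*pi (i.e. phi = 0) to close the seam
    return thetas, phis


thetas, phis = mw_angles(4)
closed_thetas, closed_phis = close_grid_for_plotting(thetas, phis)
print(thetas / np.pi)  # the last theta is exactly pi, the first is not 0
print(phis[-1] / np.pi, closed_phis[-1] / np.pi)
```

Under these formulas the grid already contains the south pole (theta = pi) but not the north pole, and the longitudes stop one step short of 2 pi, so an unmodified mesh would show a small hole at the top and a gap along one meridian. Moving the first ring onto the pole and the last meridian onto 2 pi closes both, at the cost of drawing those samples slightly away from their true positions, which is harmless for a picture.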

If you're looking for something else, or having issues with this snippet, let me know and I can try and point you in the right direction!
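The radial scaling in `plot_sphere` can also be checked in isolation. The following numpy-only sketch repeats the same arithmetic (`temp = (f - mn) / mx`, `r = sr + temp`, then spherical-to-Cartesian conversion) on a toy grid; `displace_sphere` is a hypothetical helper written for illustration, not a function from s2wav, s2fft or mayavi.

```python
import numpy as np


def displace_sphere(f, thetas, phis, sr, mx, mn):
    """Hypothetical helper repeating plot_sphere's geometry: r = sr + (f - mn) / mx.

    f must have shape (len(thetas), len(phis)): one row per theta ring,
    which matches np.meshgrid(phis, thetas) with the default 'xy' indexing.
    Returns the Cartesian coordinates and the scaled field used for colour.
    """
    phi, theta = np.meshgrid(phis, thetas)
    temp = (f - mn) / mx  # the minimum of f maps to zero displacement
    r = sr + temp         # push points outwards from a sphere of radius sr
    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)
    return x, y, z, temp


# Toy field that varies only with colatitude: f = cos(theta).
thetas = np.linspace(0.0, np.pi, 5)
phis = np.linspace(0.0, 2 * np.pi, 9)
f = np.cos(thetas)[:, None] * np.ones((1, phis.size))
mx, mn = f.max(), f.min()
x, y, z, temp = displace_sphere(f, thetas, phis, sr=5.0, mx=mx, mn=mn)
radius = np.sqrt(x**2 + y**2 + z**2)  # distance of each mesh point from the origin
print(temp.min(), temp.max(), 1 - mn / mx)
```

Because `temp = (f - mn) / mx`, a field spanning [mn, mx] is mapped to [0, 1 - mn/mx], which is exactly the `vmin=0` and `vmax=1-mn/mx` colour range passed to `mlab.mesh`. This assumes `mx > 0`; dividing by `mx - mn` instead would map the field to [0, 1] regardless of sign, which is an alternative design choice rather than what the snippet does.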

MichaelJacob914 commented 4 weeks ago

Thank you so much! I'm getting an error that wavelet_pixel_space isn't defined. Is there a specific tiling scheme I should use to define this?

CosmoMatt commented 4 weeks ago

@MichaelJacob914 sorry, there is a typo! I suspect the final line should read "plot_sphere(np.real(wavelet_coeffs), L, sr, mx, mn)".
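On the note in the snippet that the directional kernel is a single set of harmonic coefficients which "you can imagine being rotated during the transform": for a rotation about the kernel's own axis (the z-axis) by an angle gamma, this amounts to multiplying each coefficient of azimuthal order m by exp(-i m gamma). The sketch below checks that identity on a one-dimensional azimuthal Fourier series in numpy. It only illustrates the idea and does not reproduce s2wav's implementation; a general rotation on the sphere also involves the other two Euler angles (via Wigner D-functions). `rotate_about_z` and `evaluate_series` are hypothetical helper names.

```python
import numpy as np


def rotate_about_z(c, ms, gamma):
    """Rotate f(phi) = sum_m c_m exp(i m phi) by gamma: c_m -> c_m exp(-i m gamma)."""
    return c * np.exp(-1j * ms * gamma)


def evaluate_series(c, ms, phi):
    """Evaluate sum_m c_m exp(i m phi) at each angle in the array phi."""
    return np.exp(1j * np.outer(phi, ms)) @ c


N = 3                        # azimuthal bandlimit, as in the snippet
ms = np.arange(-(N - 1), N)  # orders |m| < N, i.e. 2N - 1 of them
rng = np.random.default_rng(0)
c = rng.normal(size=ms.size) + 1j * rng.normal(size=ms.size)

phi = np.linspace(0.0, 2 * np.pi, 50, endpoint=False)
gamma = 0.7
rotated = evaluate_series(rotate_about_z(c, ms, gamma), ms, phi)
shifted = evaluate_series(c, ms, phi - gamma)  # f(phi - gamma) evaluated directly
print(np.max(np.abs(rotated - shifted)))
```

With an azimuthal bandlimit N only the 2N - 1 orders |m| < N are non-zero, so rotations of the kernel about its own axis form a band-limited function of gamma that is fully determined by 2N - 1 equispaced samples, which is typically why directional transforms on the sphere evaluate 2N - 1 orientations.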