astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Inconsistent results for pure JAX implementation of inverse Wigner transformation #209

Closed ElisR closed 1 month ago

ElisR commented 2 months ago

Hi! Firstly, thanks a lot for developing this package! I'm not in astro, but the package has been just what I needed in a project I'm working on.

The issue I'm having is related to the JAX implementation of the inverse Wigner transformation s2fft.wigner.inverse_jax. I was lucky to stumble upon the SSHT version of this function that allowed me to implement what I needed, albeit on the CPU.

Below, I have added a minimal (non-)working example to illustrate the kind of use case I have. Apologies that it is still slightly verbose. I haven't delved deep into the library code to try to debug it, but from the scales it looks like catastrophic floating-point errors are happening (beyond aliasing artefacts, which is what I originally assumed).

"""Minimum example showing `s2fft`'s inverse Wigner transformation not working reliably."""

from collections.abc import Callable
from typing import TypeAlias

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy.spatial import transform

# Allowing 64-bit float before importing s2fft
jax.config.update("jax_enable_x64", True)
import s2fft  # noqa: E402  pylint: disable=C0413
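# `partial` comes from the standard library's `functools`, not from `s2fft`.
from functools import partial  # noqa: E402  pylint: disable=C0413
from s2fft import sampling  # noqa: E402  pylint: disable=C0413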

ScalarFunction: TypeAlias = Callable[[jax.Array], jax.Array]

S2FFT_ANGLE_SAMPLER: str = "dh"
S2FFT_REALITY: bool = False

def dxy_orbital(coords: jax.Array) -> jax.Array:
    """(Arbitrary) function that looks like a 3dxy orbital."""
    # Scaling coordinates to get shape mostly contained in frame.
    coords *= 8.0
    x_s = coords[..., 0]
    y_s = coords[..., 1]
    r_s = jnp.linalg.vector_norm(coords, axis=-1)
    return jnp.exp(-r_s / 2) * x_s * y_s

def _generate_samples(func: ScalarFunction, thetas: jax.Array, phis: jax.Array) -> jax.Array:
    """Generate a signal using the given sampling angles.

    Axes will correspond to `(theta, phi)`.
    """
    x_s = jnp.sin(thetas)[:, None] * jnp.cos(phis)[None, :]
    y_s = jnp.sin(thetas)[:, None] * jnp.sin(phis)[None, :]
    z_s = jnp.cos(thetas)[:, None] + 0 * x_s
    coords = jnp.stack([x_s, y_s, z_s], axis=-1)
    return func(coords)

def get_correlation(
    func: ScalarFunction, rotation: jax.Array, lmax: int, *, use_ssht: bool
) -> jax.Array:
    """Show that `jax` wrapper for inverse Wigner is not behaving reliably."""

    def _rotated_func(func: ScalarFunction, rotation: jax.Array, coords: jax.Array) -> jax.Array:
        """Wrap a scalar function with its rotated version."""
        rotated_coords = jnp.einsum("ij,...j->...i", rotation.T, coords)
        return func(rotated_coords)

    func_rotated = partial(_rotated_func, func, rotation)

    thetas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    phis = jnp.array(sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER))

    def _get_coeffs(func: ScalarFunction) -> jax.Array:
        sampled = _generate_samples(func, thetas, phis)
        coeffs = s2fft.forward_jax(
            sampled, lmax, reality=S2FFT_REALITY, sampling=S2FFT_ANGLE_SAMPLER
        )
        return coeffs

    func_coeffs = _get_coeffs(func)
    func_rotated_coeffs = _get_coeffs(func_rotated)

    outer_product = (
        jnp.matrix_transpose(jnp.conjugate(func_rotated_coeffs))[:, :, None]
        * func_coeffs[None, :, :]
    )
    assert outer_product.dtype == jnp.complex128

    # This is the part that varies
    inverse_args = (outer_product, lmax, lmax)
    inverse_kwargs = {"reality": S2FFT_REALITY, "sampling": S2FFT_ANGLE_SAMPLER}
    so3_correlation = (
        s2fft.wigner.inverse_jax_ssht(*inverse_args, **inverse_kwargs)
        if use_ssht
        else s2fft.wigner.inverse_jax(*inverse_args, **inverse_kwargs)
    )
    return jnp.real(so3_correlation)

def plot_strange_results() -> None:
    """Plot strange results from the above function."""
    print(f"{jax.config.read('jax_enable_x64')=}")

    # Get scores
    rotation = transform.Rotation.from_rotvec(jnp.array([0, 0, jnp.pi / 2])).as_matrix()
    lmax = 64
    ssht_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=True)
    jax_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=False)

    # Choose angles for plotting
    best_index = jnp.unravel_index(jnp.argmax(ssht_score), jnp.shape(ssht_score))
    best_gamma, best_beta, best_alpha = best_index
    betas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    alphas = gammas = jnp.array(
        sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER)
    )

    # Plot output
    fig, axes = plt.subplots(ncols=3, nrows=2, squeeze=False)
    plt.subplots_adjust(left=0.1, right=0.9, top=0.95, bottom=0.05, hspace=0.4, wspace=0.3)
    for i, (score, name) in enumerate([(ssht_score, "SSHT"), (jax_score, "JAX")]):
        axes[i, 0].plot(alphas, score[best_gamma, best_beta, :])
        axes[i, 0].title.set_text(f"Corr vs α for {name}")
        axes[i, 1].plot(betas, score[best_gamma, :, best_alpha])
        axes[i, 1].title.set_text(f"Corr vs β for {name}")
        axes[i, 2].plot(gammas, score[:, best_beta, best_alpha])
        axes[i, 2].title.set_text(f"Corr vs γ for {name}")
    fig.savefig("strange_s2fft_inverse_wigner.png", bbox_inches="tight", dpi=300)
    plt.close(fig)

if __name__ == "__main__":
    plot_strange_results()
ElisR commented 2 months ago

Here's the output I'm getting on Apple Silicon macOS (running on CPU), though my colleague on Ubuntu also gets the same thing on CPU.

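[Image: strange_s2fft_inverse_wigner.png, the figure saved by the script: correlation against α, β and γ, with the SSHT result in the top row and the JAX result in the bottom row]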

jasonmcewen commented 2 months ago

Thanks for your comments @ElisR. We'll try to look into this soon. We've tested the Wigner transforms and they should be working fine, but there is a known issue: at present they are only stable for a relatively low azimuthal bandlimit N (around O(10)). I believe a warning/error is displayed if one tries to go beyond that (is that the case @CosmoMatt?). We have plans to fix this in the very near future. Is this the issue you were running into, or is this something different (apologies, I haven't had a chance to look at your code example yet)?

ElisR commented 1 month ago

Thanks for your response @jasonmcewen and apologies for my late reply.

You're right that the issue must be coming from the large azimuthal bandlimit N that I require (for aligning two spherical signals). I now also see the warning you mention when calling s2fft.transforms.wigner.inverse. Unfortunately that warning doesn't appear when calling inverse_jax directly.

Perhaps as a temporary measure I will make a PR that adds this note to the docstring of inverse_jax.
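Until such a note or warning exists on the direct entry point, a caller could add the check themselves. The following is a minimal sketch, not part of s2fft: `guarded_inverse` and `MAX_STABLE_N` are hypothetical names, and the threshold value is an assumption based only on the "around O(10)" figure quoted in this thread. The transform is passed in as a callable, so the sketch does not depend on any particular s2fft signature beyond the `(flmn, L, N, **kwargs)` call pattern used in the script above.

```python
import warnings
from typing import Any, Callable

# Assumed threshold: the thread only says the transforms are stable for
# N of "around O(10)". Adjust to whatever the library documents.
MAX_STABLE_N = 8


def guarded_inverse(
    transform: Callable[..., Any], flmn: Any, L: int, N: int, **kwargs: Any
) -> Any:
    """Call `transform(flmn, L, N, **kwargs)`, warning first if N is large."""
    if N > MAX_STABLE_N:
        warnings.warn(
            f"Azimuthal bandlimit N={N} exceeds {MAX_STABLE_N}; "
            "the result may be numerically unreliable.",
            RuntimeWarning,
            stacklevel=2,  # point the warning at the caller, not this helper
        )
    return transform(flmn, L, N, **kwargs)
```

With the script above, this would be called as `guarded_inverse(s2fft.wigner.inverse_jax, outer_product, lmax, lmax, reality=S2FFT_REALITY, sampling=S2FFT_ANGLE_SAMPLER)`, which would emit a `RuntimeWarning` for `lmax = 64` before running the transform.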

I'll be curious to know what's causing the numerical instability at large N. Best of luck with the fix - I really love the package otherwise!

jasonmcewen commented 1 month ago

Thanks for catching this @ElisR ! We're getting back into further development of s2fft again now after some other commitments. If you could create a PR to add a warning that would be fantastic. We'll then add you to all contributors for the code.

In terms of the numerical stability, it's related to the recursions that we use. We've already implemented alternative recursions that don't suffer from this instability, although they may not distribute over GPUs quite so efficiently. We just need to integrate these into the harmonic transforms and will try to get to that soon. Let us know if you're interested in helping out with this and we could help to support that? Otherwise we'll try to get to it soon but we have a few other todos first.
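For readers unfamiliar with how a recursion can be exact on paper yet unstable in floating point, here is a generic textbook illustration. It is not the recursion used in s2fft, and the function names are made up for this sketch. The integrals I_n = ∫₀¹ xⁿ e^(x-1) dx satisfy I_n = 1 - n·I_(n-1) with I_0 = 1 - 1/e. Running that recurrence forwards multiplies any rounding error by n at each step, whereas running it backwards from a crude starting guess divides the error by n at each step.

```python
import math


def forward(n: int) -> float:
    """I_n via the forward recurrence I_k = 1 - k * I_{k-1}, starting at I_0."""
    value = 1.0 - math.exp(-1.0)  # I_0 = 1 - 1/e, already rounded
    for k in range(1, n + 1):
        value = 1.0 - k * value  # any error in `value` is multiplied by k
    return value


def backward(n: int, extra: int = 30) -> float:
    """I_n via the reversed recurrence I_{k-1} = (1 - I_k) / k, started above n."""
    value = 0.0  # crude guess for I_{n+extra}; its error shrinks at every step
    for k in range(n + extra, n, -1):
        value = (1.0 - value) / k
    return value


n = 30
# Since e^(x-1) lies between 1/e and 1 on [0, 1], I_n lies between these bounds.
lower, upper = 1.0 / (math.e * (n + 1)), 1.0 / (n + 1)
print(f"forward : {forward(n):.6e}")
print(f"backward: {backward(n):.6e}")
print(f"bounds  : ({lower:.6e}, {upper:.6e})")
```

In double precision the forward value lands far outside the bounds while the backward value sits inside them, even though both loops evaluate the same mathematical identity. Choosing a recursion direction or form that damps errors rather than amplifying them is the general idea behind switching to "alternative recursions".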

ElisR commented 1 month ago

Hi @jasonmcewen - made a PR #220 to acknowledge the low accuracy beyond N = 8 in the docstring.

I'd be happy to try and help bring over the stable recursions into the transforms, if you could point me to the existing work. Is there a branch with these recursions, or are they already in main? Once I've had a quick look then I should be able to say whether I think I can implement it myself in reasonable time in the next couple of weeks.

jasonmcewen commented 1 month ago

Thanks for this @ElisR. We'll review this shortly and get it into main.

Great that you're interested in getting involved in the integration of the alternative recursions. Let's pick up the discussion about that in another issue shortly...

jasonmcewen commented 1 month ago

Closing this issue since the documentation updates regarding warnings were added in this PR and we'll pick up the discussion about integrating other recursions for stability for high azimuthal bandlimits in this issue.