astro-informatics / s2scat

Differentiable and GPU accelerated scattering covariance statistics on the sphere
https://astro-informatics.github.io/s2scat/
MIT License
compression differentiable-programming emulation generative-model jax scattering-transform spherical statistics wavelets


Differentiable scattering covariances on the sphere

S2SCAT is a Python package for computing scattering covariances on the sphere (Mousset et al. 2024) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in S2FFT and S2WAV, respectively. Scattering covariances are useful both for field-level generative modelling of complex non-Gaussian textures and for statistical compression of high-dimensional field-level data, a key step of, e.g., simulation-based inference.

[!IMPORTANT] It is worth highlighting that the inputs to S2SCAT are spherical harmonic coefficients, which can be generated with whichever software package you prefer, e.g. S2FFT or healpy. Just ensure your harmonic coefficients are indexed using our convention; helper functions for this reindexing can be found in S2FFT.
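S2FFT's own reindexing helpers should be preferred in practice. Purely to illustrate what such a conversion involves, the sketch below maps healpy's packed `alm` array (m-major ordering, only m ≥ 0 stored, position `m*(2*lmax+1-m)//2 + el`) onto a 2D array of shape (L, 2L−1) indexed as `flm[el, L-1+m]`. That 2D layout is our assumption about the convention, the helper names `healpy_index` and `hp_to_2d` are hypothetical, and we assume both packages share the same normalisation and phase conventions. Negative-m coefficients are filled using the reality condition f_{ℓ,−m} = (−1)^m f*_{ℓm}, which holds only for real-valued fields.

```python
import numpy as np

def healpy_index(el: int, m: int, lmax: int) -> int:
    """Position of a_{el m} (0 <= m <= el) in healpy's packed, m-major alm array."""
    return m * (2 * lmax + 1 - m) // 2 + el

def hp_to_2d(alm_hp: np.ndarray, L: int) -> np.ndarray:
    """Hypothetical helper: healpy packed alm (lmax = L - 1) -> array of shape (L, 2L - 1).

    Assumes the layout flm[el, L - 1 + m], with zeros wherever |m| > el, and a
    real-valued field, so negative-m entries follow f_{el,-m} = (-1)^m conj(f_{el,m}).
    """
    lmax = L - 1
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for el in range(L):
        for m in range(el + 1):
            value = alm_hp[healpy_index(el, m, lmax)]
            flm[el, L - 1 + m] = value
            if m > 0:
                flm[el, L - 1 - m] = (-1) ** m * np.conj(value)
    return flm

# Example: random coefficients for L = 4 (healpy stores L(L+1)/2 = 10 of them).
L = 4
rng = np.random.default_rng(0)
alm_hp = rng.normal(size=10) + 1j * rng.normal(size=10)
alm_hp[:L] = alm_hp[:L].real  # the m = 0 coefficients of a real field are real
flm = hp_to_2d(alm_hp, L)
print(flm.shape)  # (4, 7)
```

The explicit double loop favours readability over speed; a vectorised or library-provided routine is preferable for large L.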

[!TIP] At launch, S2SCAT provides two core transform modes: on-the-fly, which performs the underlying spherical harmonic and Wigner transforms through the Price & McEwen recursion; and precompute, which computes and caches all required Wigner elements a priori. The precompute approach is faster but can only be run up to $L \sim 512$, whereas the on-the-fly approach runs up to $L \sim 2048$ and potentially beyond, depending on GPU hardware.

Ballpark compute times (when running on a 40GB A100 GPU) and compression levels are given in the table below.

| Method | Resolution | Forward pass | Gradient pass | JIT compilation | Input params | Anisotropic (compression) | Isotropic (compression) |
|---|---|---|---|---|---|---|---|
| Precompute | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~63,000 (97.594%) | ~504 (99.981%) |
| On-the-fly | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~123,750 (99.705%) | ~990 (99.998%) |
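The compression percentages in the table follow from the ratio of output statistics to input parameters, compression = 1 − (number of statistics / number of input parameters). A short check in Python, using only the numbers quoted in the table (the helper name `compression` is ours):

```python
def compression(n_stats: int, n_inputs: int) -> float:
    """Percentage reduction in dimensionality when n_inputs values are summarised by n_stats."""
    return 100.0 * (1.0 - n_stats / n_inputs)

# (method, input params, anisotropic statistics, isotropic statistics), as quoted in the table.
rows = [
    ("Precompute, L=512", 2_618_880, 63_000, 504),
    ("On-the-fly, L=2048", 41_932_800, 123_750, 990),
]

for name, n_in, n_aniso, n_iso in rows:
    print(f"{name}: anisotropic {compression(n_aniso, n_in):.3f}%, "
          f"isotropic {compression(n_iso, n_in):.3f}%")
# Precompute, L=512: anisotropic 97.594%, isotropic 99.981%
# On-the-fly, L=2048: anisotropic 99.705%, isotropic 99.998%
```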

Note that the times in the table above are for unbatched evaluations, so in practice throughput may be substantially higher. For example, with a large batch size at $L=256, N=3$, textures are generated at a rate of 500ms per texture on a single 40GB NVIDIA A100 GPU. Further note that everything here runs at 64-bit precision; relaxing to 32-bit precision reduces computation time and memory by a factor of 2, and facilitates larger batches for further acceleration.
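In JAX, 64-bit types are opt-in via the `jax_enable_x64` configuration flag. The snippet below is a generic JAX sketch, not S2SCAT-specific code (how S2SCAT itself selects precision is not covered here). It shows the factor-of-two memory difference for a complex array of shape (L, 2L−1) at L = 512, the shape we assume for a 2D array of harmonic coefficients.

```python
import jax

# JAX uses 32-bit types by default; 64-bit types must be enabled explicitly,
# before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

L = 512
# Complex array of shape (L, 2L - 1): an assumed 2D harmonic-coefficient layout.
flm_64 = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)
flm_32 = flm_64.astype(jnp.complex64)

print(flm_64.nbytes, flm_32.nbytes)  # 8380416 4190208
```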

Scattering covariances :dna:

We introduce scattering covariances on the sphere in Mousset et al. (2024), extending to the spherical setting the similar scattering transforms introduced for 1D signals by Morel et al. (2023) and for planar 2D signals by Cheng et al. (2023). Scattering covariances $S$ are computed by

$$S_1^{\lambda_1} = \langle |W^{\lambda_1} I| \rangle,$$

$$S_2^{\lambda_1} = \langle|W^{\lambda_1} I|^2 \rangle,$$

$$S_3^{\lambda_1, \lambda_2} = \text{Cov} \left[ W^{\lambda_1}I, W^{\lambda_1}|W^{\lambda_2} I| \right],$$

$$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lambda_3}I|, W^{\lambda_1}|W^{\lambda_2}I|\right]$$

where $W^{\lambda} I$ denotes the wavelet transform of field $I$ at scale $j$ and direction $\gamma$, which we group into a single label $\lambda=(j,\gamma)$.

This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scales and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure.
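To make the four definitions concrete, here is a minimal flat analogue in JAX. It is not S2SCAT's implementation: it works on a periodic 1D signal, uses hypothetical isotropic Gaussian band-pass filters (`toy_wavelet_filters`) in place of S2WAV's directional spherical wavelets, so that λ reduces to the scale j, and it computes every index combination without the scale-ordering restrictions or normalisation a production implementation would typically apply. Averages and covariances are taken over pixels.

```python
import jax
import jax.numpy as jnp

def toy_wavelet_filters(n, J):
    """Hypothetical 1D analytic band-pass filters on a periodic grid, shape (J, n).

    Filter j is a Gaussian bump centred on frequency n / 2**(j + 2), kept only at
    positive frequencies so that filtered signals are complex (analytic).
    """
    k = jnp.fft.fftfreq(n) * n                        # integer frequencies
    centres = n / 2.0 ** (jnp.arange(J) + 2)          # one octave apart
    widths = centres / 2.0
    filt = jnp.exp(-0.5 * ((k[None, :] - centres[:, None]) / widths[:, None]) ** 2)
    return jnp.where(k[None, :] > 0, filt, 0.0)

def wavelet_transform(x, filters):
    """W^{lambda} x for every filter at once, via the FFT: complex array (J, n)."""
    return jnp.fft.ifft(jnp.fft.fft(x)[None, :] * filters, axis=-1)

def cov(a, b):
    """Empirical covariance over the last (pixel) axis: <(a - <a>) conj(b - <b>)>."""
    a = a - a.mean(axis=-1, keepdims=True)
    b = b - b.mean(axis=-1, keepdims=True)
    return (a * jnp.conj(b)).mean(axis=-1)

def scattering_covariances(x, filters):
    W = wavelet_transform(x, filters)                 # W[l1] = W^{l1} I
    S1 = jnp.abs(W).mean(axis=-1)                     # <|W^{l1} I|>
    S2 = (jnp.abs(W) ** 2).mean(axis=-1)              # <|W^{l1} I|^2>
    # Second layer: A[l1, l2] = W^{l1} |W^{l2} I|, shape (J, J, n).
    A = jnp.swapaxes(jax.vmap(lambda m: wavelet_transform(m, filters))(jnp.abs(W)), 0, 1)
    # S3[l1, l2] = Cov[W^{l1} I, W^{l1}|W^{l2} I|]
    S3 = cov(W[:, None, :], A)
    # S4[l1, l2, l3] = Cov[W^{l1}|W^{l3} I|, W^{l1}|W^{l2} I|]
    S4 = cov(A[:, None, :, :], A[:, :, None, :])
    return S1, S2, S3, S4

x = jax.random.normal(jax.random.PRNGKey(0), (256,))  # toy "field": white noise
filters = toy_wavelet_filters(256, 4)                  # J = 4 scales
S1, S2, S3, S4 = scattering_covariances(x, filters)
print(S1.shape, S2.shape, S3.shape, S4.shape)          # (4,) (4,) (4, 4) (4, 4, 4)
```

Because every step is a JAX operation, the same function can be passed to `jax.grad` or `jax.jit` unchanged, which is the property the package relies on.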

Building on the recently released JAX spherical harmonic code S2FFT (Price & McEwen 2024) and spherical wavelet transform code S2WAV (Price et al. 2024), S2SCAT extends scattering covariances to the sphere, which is necessary for their application to generative modelling of wide-field cosmological fields (Mousset et al. 2024).

Usage :rocket:

Importing and using S2SCAT is as simple as follows:

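import s2scat, jax

# Placeholder inputs (example values; substitute your own):
L, N = 256, 3   # Harmonic band-limit and directional band-limit.
seed = 0        # Random seed used for generation.
# alm: spherical harmonic coefficients of your field, indexed using the S2FFT convention.

# For statistical compression
encoder = s2scat.build_encoder(L, N)          # Returns a callable compression model.
covariance_statistics = encoder(alm)          # Generate statistics (can be batched).

# For generative modelling
key = jax.random.PRNGKey(seed)
generator = s2scat.build_generator(alm, L, N) # Returns a callable generative model.
new_samples = generator(key, 10)              # Generate 10 new spherical textures.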

For further details on usage see the documentation and associated notebooks.
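This README does not spell out how `build_generator` works internally. A common approach in the scattering literature, and the idea we assume underlies S2SCAT's generative mode, is to start from random noise and run gradient descent until the statistics of the new sample match those of the target. The flat 1D sketch below illustrates that principle with JAX autodiff. `make_filters`, `toy_stats` and `generate` are hypothetical names, the statistics are a reduced toy set (only S1- and S2-like terms), and plain gradient descent stands in for whatever optimiser S2SCAT actually uses.

```python
import jax
import jax.numpy as jnp

def make_filters(n, J):
    """Hypothetical real band-pass filters on a periodic 1D grid, shape (J, n)."""
    k = jnp.abs(jnp.fft.fftfreq(n) * n)               # |integer frequency|
    centres = n / 2.0 ** (jnp.arange(J) + 2)          # one octave apart
    filt = jnp.exp(-0.5 * ((k[None, :] - centres[:, None]) / (centres[:, None] / 2)) ** 2)
    return jnp.where(k[None, :] > 0, filt, 0.0)       # zero the mean (k = 0) mode

def toy_stats(x, filters):
    """Toy differentiable statistics: <|W x|> and <|W x|^2> for every band."""
    W = jnp.fft.ifft(jnp.fft.fft(x)[None, :] * filters, axis=-1)
    return jnp.concatenate([jnp.abs(W).mean(-1), (jnp.abs(W) ** 2).mean(-1)])

def generate(key, target, filters, steps=300, lr=1.0):
    """Gradient descent from white noise towards the target's statistics."""
    target_stats = toy_stats(target, filters)

    def loss(x):
        return jnp.sum((toy_stats(x, filters) - target_stats) ** 2)

    @jax.jit
    def step(x):
        return x - lr * jax.grad(loss)(x)             # one gradient-descent update

    x = jax.random.normal(key, target.shape)          # initial white-noise sample
    for _ in range(steps):
        x = step(x)
    return x, loss

key = jax.random.PRNGKey(0)
target = 2.0 * jax.random.normal(jax.random.PRNGKey(1), (256,))  # stand-in target field
filters = make_filters(256, 4)
sample, loss = generate(key, target, filters)
print(float(loss(jax.random.normal(key, (256,)))), float(loss(sample)))
```

The loss is compared on statistics rather than pixels, so the generated sample is a new realisation that shares the target's (toy) statistics rather than a copy of the target.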

Package Directory Structure :art:

s2scat/  
├── representation.py   # - Scattering covariance transform.
├── compression.py      # - Statistical compression functions.
├── optimisation.py     # - Optimisation algorithm wrappers. 
├── generation.py       # - Latent encoder and Generative decoder.
│    
├── operators/          # Internal functionality:
│      ├─ spherical.py          # - Specific spherical operations, e.g. batched SHTs.
│      ├─ matrices.py           # - Wrappers to generate cached values. 
│
├── utility/            # Convenience functionality:
│      ├─ reorder.py            # - Reindexing and converting lists and arrays.
│      ├─ statistics.py         # - Calculation of covariance statistics. 
│      ├─ normalisation.py      # - Normalisation functions for covariance statistics. 
│      ├─ plotting.py           # - Plotting functions for signals and statistics.

Installation :computer:

The Python dependencies for the S2SCAT package are listed in the file requirements/requirements-core.txt and will be automatically installed into the active Python environment by pip when running

pip install s2scat

This will install all core functionality which includes full JAX support.

Alternatively, the S2SCAT package may be installed directly from GitHub by cloning this repository and then running

pip install .        

from the root directory of the repository.

Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest

pip install -r requirements/requirements-tests.txt
pytest tests/  

Documentation for the released version is available at https://astro-informatics.github.io/s2scat/.

Contributors

Matt Price: 🤔 💻 🎨 📖

mousset: 💻 🎨 🤔

Jason McEwen: 🤔 💻 📖

Eralys: 🤔

Attribution :books:

Should this code be used in any way, we kindly request that the following article is referenced. A BibTeX entry for this reference may look like:

    @article{mousset:s2scat, 
        author      = "Louise Mousset et al",
        title       = "TBD",
        journal     = "TBD, submitted",
        year        = "2024",
        eprint      = "TBD"        
    }

You might also like to consider citing our related papers on which this code builds:

    @article{price:s2fft, 
        author      = "Matthew A. Price and Jason D. McEwen",         
        title        = "Differentiable and accelerated spherical harmonic and {W}igner transforms",
        journal      = "Journal of Computational Physics",
        volume       = "510",
        pages        = "113109",        
        year         = "2024",
        doi          = {10.1016/j.jcp.2024.113109},
        eprint       = "arXiv:2311.14670"        
    }
    @article{price:s2wav, 
        author      = "Matthew A. Price and Alicja Polanska and Jessica Whitney and Jason D. McEwen",
        title       = "Differentiable and accelerated directional wavelet transform on the sphere and ball",
        year        = "2024",
        eprint      = "arXiv:2402.01282"
    }

License :memo:

We provide this code under an MIT open-source licence with the hope that it will be of use to a wider community.

Copyright 2024 Louise Mousset, Matthew Price, Erwan Allys and Jason McEwen

S2SCAT is free software made available under the MIT License. For details see the LICENSE file.