astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

[Suggestion] : Remove strict requirement on backends. #229

Open ASKabalan opened 1 month ago

ASKabalan commented 1 month ago

Currently, we have to install PyTorch even to use the JAX backend of s2fft.

I think it would be nice to be able to conditionally activate backends depending on whether the corresponding package is available.

For example, a user could use s2fft with only NumPy if they have neither JAX nor PyTorch installed; trying to use a missing backend would then raise a runtime error rather than an import-time error.

The same goes for pyssht. I don't know the best practice for this, but a try/except around the import plus a global boolean flag should do the trick.
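A minimal sketch of the try/except-plus-global-flag pattern, assuming PyTorch as the optional backend. The names `TORCH_AVAILABLE`, `_require_torch` and `forward_torch` are hypothetical and not part of s2fft's API. Importing the module always succeeds; the error only surfaces when the PyTorch path is actually called:

```python
# Hypothetical module-level guard; TORCH_AVAILABLE, _require_torch and
# forward_torch are illustrative names, not part of s2fft's current API.
try:
    import torch  # optional backend
    TORCH_AVAILABLE = True
except ImportError:
    torch = None
    TORCH_AVAILABLE = False


def _require_torch(func_name: str) -> None:
    """Raise an informative error at call time if PyTorch is missing."""
    if not TORCH_AVAILABLE:
        raise ImportError(
            f"{func_name} requires PyTorch, which is not installed. "
            "Install PyTorch to use the PyTorch backend."
        )


def forward_torch(f):
    """Hypothetical PyTorch-backed entry point."""
    _require_torch("forward_torch")
    # A real transform would go here; this sketch only converts the input.
    return torch.as_tensor(f)
```

Users without PyTorch can still import the module and use the NumPy or JAX paths; only a call to `forward_torch` fails, with a message naming the missing package.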

jasonmcewen commented 1 month ago

Thanks for the suggestion @ASKabalan! At present we only have PyTorch support for the precompute approach but we do plan to add PyTorch support for the on-the-fly transforms as well. So it would be good to have a nice setup like you're suggesting.

@matt-graham do you have any thoughts on this or know of best practices?

matt-graham commented 1 month ago

Hi @jasonmcewen. This is related to what I suggested in #224 - we can easily make all of PyTorch, pyssht and healpy optional dependencies and guard the imports from these packages within try: ... except ImportError: ... logic, either at a module level or within the relevant functions themselves, and raise an informative error if a user tries to use a function for which the relevant optional dependencies are not available.
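Guarding the import inside the relevant function can be sketched with a small decorator. This assumes the availability check uses `importlib.util.find_spec` before a deferred import. `requires` and `healpy_map2alm` are hypothetical names, and mapping s2fft's band-limit `L` to healpy's `lmax = L - 1` is an assumption about conventions:

```python
import functools
import importlib.util


def requires(*packages: str):
    """Hypothetical decorator: raise an informative ImportError when the
    wrapped function is called if any optional package is not installed."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # find_spec returns None for a top-level package that cannot be
            # found, without importing it.
            missing = [p for p in packages if importlib.util.find_spec(p) is None]
            if missing:
                raise ImportError(
                    f"{func.__name__} requires the optional package(s) "
                    f"{', '.join(missing)}, which are not installed."
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@requires("healpy")
def healpy_map2alm(f, L: int):
    """Hypothetical wrapper: healpy is only imported when this is called."""
    import healpy  # deferred import of the optional dependency

    # Assumes band-limit L means multipoles ell < L, hence lmax = L - 1.
    return healpy.map2alm(f, lmax=L - 1)
```

Compared with a module-level flag, this keeps each optional dependency local to the functions that need it. The trade-off is a small lookup cost on every call.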

I would say dropping hard requirements on all external packages which are not 'core' to the package would be good practice for a library like s2fft as it avoids forcing users who only want to rely on a subset of the functionality to install unnecessary dependencies, and as I mentioned in #224 extends the range of systems that users can install the package on. The overhead for those users who wish to use the optional features is minimal as we can add relevant optional dependency groups so that for example they could just do pip install s2fft[all] to install all optional dependencies, or pip install s2fft[torch] to just install extra dependencies required for PyTorch support and so on.

matt-graham commented 1 month ago

Just noticed the suggestion to also apply this to JAX - while we could do this for JAX too, a lot of the modules (perhaps the majority?) make multiple imports from the jax package namespace, so guarding all of these imports could get a bit unwieldy. What might help with reducing our explicit imports from the JAX and PyTorch APIs (and the duplication of functions) is to use their support for the array API, by doing something like

import array_api_compat

...

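def spectral_periodic_extension(fm, L: int):
    # Infer the array namespace (NumPy, JAX, PyTorch, ...) from the input itself
    xp = array_api_compat.array_namespace(fm)
    nphi = fm.shape[0]
    # `concat` is the array API standard name for concatenation
    return xp.concat(
        (
            fm[-xp.arange(L - nphi // 2, 0, -1) % nphi],
            fm,
            fm[xp.arange(L - (nphi + 1) // 2) % nphi],
        )
    )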

This would make the function compatible with all of NumPy, JAX and PyTorch arrays without requiring an explicit import from any of them and also avoid having multiple _jax, _torch variants of functions.
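To make concrete what `spectral_periodic_extension` computes, here is the same indexing written directly against NumPy (the backend-specific style that the array API version would replace), applied to a small made-up input. The name `periodic_extension_numpy` is illustrative only. The input of length `nphi` is wrapped periodically on both sides, and for `nphi` up to `2L` the result has length `2L`:

```python
import numpy as np


def periodic_extension_numpy(fm: np.ndarray, L: int) -> np.ndarray:
    """NumPy-only version of the indexing in spectral_periodic_extension."""
    nphi = fm.shape[0]
    # Wrap the last entries of fm around to the front...
    head = fm[-np.arange(L - nphi // 2, 0, -1) % nphi]
    # ...and the first entries around to the back.
    tail = fm[np.arange(L - (nphi + 1) // 2) % nphi]
    return np.concatenate((head, fm, tail))


fm = np.arange(3)  # nphi = 3
print(periodic_extension_numpy(fm, L=3))  # [1 2 0 1 2 0]
```

Each operation used here (arange, concatenation, integer-array indexing) has a counterpart in the JAX and PyTorch namespaces, which is what allows a single array API implementation to serve all three.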

jasonmcewen commented 1 month ago

> [...] we can add relevant optional dependencies groups so that for example they could just do pip install s2fft[all] to install all optional dependencies, or pip install s2fft[torch] to just install extra dependencies required for PyTorch support and so on.

Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than from PyPI tho? Perhaps that is not a big issue anyway.

jasonmcewen commented 1 month ago

> Just noticed suggestion to also apply this to JAX [...] What might help with reducing our explicit imports from JAX and PyTorch APIs (and duplication of functions) is to use their support for the array API [...]

Maybe. This seems like quite a big revision tho. Let's discuss further...

matt-graham commented 1 month ago

> Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than PyPI tho? Perhaps that is not a big issue anyway.

No, the extras / optional dependencies syntax works fine both for packages installed from a local directory and for packages installed from a package index like PyPI. It's the same syntax that JAX uses to install from PyPI with CUDA support:

pip install jax[cuda12]

jasonmcewen commented 1 month ago

Ok, sounds like it makes a lot of sense then.