ASKabalan opened 1 month ago
Thanks for the suggestion @ASKabalan! At present we only have PyTorch support for the precompute approach, but we do plan to add PyTorch support for the on-the-fly transforms as well, so it would be good to have a clean setup like the one you're suggesting.
@matt-graham do you have any thoughts on this or know of best practices?
Hi @jasonmcewen. This is related to what I suggested in #224: we can easily make all of PyTorch, pyssht and healpy optional dependencies and guard the imports from these packages with `try: ... except ImportError: ...` logic, either at module level or within the relevant functions themselves, and raise an informative error if a user tries to use a function whose optional dependencies are not available.
I would say dropping hard requirements on all external packages that are not 'core' to the package would be good practice for a library like s2fft: it avoids forcing users who only want a subset of the functionality to install unnecessary dependencies and, as I mentioned in #224, it extends the range of systems that users can install the package on. The overhead for users who do want the optional features is minimal, as we can add optional dependency groups so that, for example, they could just run `pip install s2fft[all]` to install all optional dependencies, or `pip install s2fft[torch]` to install only the extra dependencies required for PyTorch support, and so on.
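A minimal sketch of the guarded-import pattern described above. The helpers `optional_import` and `requires` and the function `to_torch` are hypothetical (not part of s2fft); the extra names in the error message follow the groups proposed here. The import is attempted once at module level and never fails, and the informative error is only raised when a function that needs the missing package is actually called. The same helper would work for pyssht or healpy.

```python
import functools
import importlib
from types import ModuleType
from typing import Callable, Optional


def optional_import(name: str) -> Optional[ModuleType]:
    """Return the imported module `name`, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def requires(module: Optional[ModuleType], package: str, extra: str) -> Callable:
    """Decorator that defers the missing-dependency error to call time."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if module is None:
                raise ImportError(
                    f"{func.__name__}() requires the optional dependency "
                    f"'{package}'; install it with: pip install s2fft[{extra}]"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Guarded module-level import: this line never raises, even if PyTorch is absent
torch = optional_import("torch")


@requires(torch, package="torch", extra="torch")
def to_torch(x):
    # Hypothetical function that needs PyTorch only when it is actually called
    return torch.as_tensor(x)
```

Without PyTorch installed, importing the module still succeeds, and only calling `to_torch` raises an `ImportError` that names the missing package and the extra to install.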
Just noticed the suggestion to also apply this to JAX. While we could do this for JAX too, a lot of the modules (perhaps the majority?) make multiple imports from the `jax` package namespace, so guarding all of these imports could get a bit unwieldy. What might help reduce our explicit imports from the JAX and PyTorch APIs (and duplication of functions) is to use their support for the array API standard, by doing something like:
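```python
import array_api_compat

...

def spectral_periodic_extension(fm, L: int):
    # Look up the array API namespace matching the input (NumPy, JAX or PyTorch)
    xp = array_api_compat.array_namespace(fm)
    nphi = fm.shape[0]
    # Periodically extend fm along axis 0 by wrapping indices modulo nphi;
    # concat is the array API standard name for concatenation
    return xp.concat(
        (
            fm[-xp.arange(L - nphi // 2, 0, -1) % nphi],
            fm,
            fm[xp.arange(L - (nphi + 1) // 2) % nphi],
        )
    )
```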
This would make the function compatible with NumPy, JAX and PyTorch arrays without requiring an explicit import from any of them, and would also avoid having separate `_jax` and `_torch` variants of functions.
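To make the index arithmetic in `spectral_periodic_extension` concrete, the same computation can be traced with plain NumPy on a toy input (the values `nphi = 5` and `L = 4` are illustrative only). The array is padded on the left with entries taken cyclically from its end and on the right with entries taken cyclically from its start.

```python
import numpy as np

# Toy inputs for illustration: nphi = 5 samples, L = 4
fm = np.array([0, 1, 2, 3, 4])
L = 4
nphi = fm.shape[0]

# Left padding: indices -(L - nphi // 2), ..., -1 wrapped modulo nphi
left_idx = -np.arange(L - nphi // 2, 0, -1) % nphi
# Right padding: indices 0, ..., L - (nphi + 1) // 2 - 1 wrapped modulo nphi
right_idx = np.arange(L - (nphi + 1) // 2) % nphi

extended = np.concatenate((fm[left_idx], fm, fm[right_idx]))
print(left_idx, right_idx)  # [3 4] [0]
print(extended)             # [3 4 0 1 2 3 4 0]
```

Because `nphi // 2 + (nphi + 1) // 2 == nphi`, the result has `2 * L` entries for either parity of `nphi` (assuming `L >= (nphi + 1) // 2`).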
> we can easily make all of PyTorch, pyssht and healpy optional dependencies
Ok, fair enough @matt-graham! Does this mean users would need to install from source rather than PyPI though? Perhaps that is not a big issue anyway.
> What might help with reducing our explicit imports from JAX and PyTorch APIs (and duplication of functions) is to use their support for the array API
Maybe. This seems like quite a big revision though. Let's discuss further...
> Does this mean users would need to install from source rather than PyPI though?
No, the extras (optional dependencies) syntax works both for packages installed from a local directory and for packages installed from a package index like PyPI. It's the same syntax that JAX uses to install from PyPI with CUDA support:

```
pip install jax[cuda12]
```
Ok, sounds like it makes a lot of sense then.
Currently, we have to install PyTorch to use the JAX backend of s2fft.
I think it would be nice to be able to conditionally activate backends depending on which packages are available.
For example, a user could use s2fft with only NumPy if they do not have JAX or PyTorch installed, and would get a runtime error, rather than an import-time error, only when trying to use a backend that is missing.
The same goes for pyssht. I don't know the best practice for this, but a try/except around the import plus a global boolean flag should do the trick.