This PR adds JAX backend support to Macenko.

Changes:
- Implemented Macenko with a JAX backend and added it to base
- Added CI jobs with JAX unit tests (checking that the JAX backend yields results similar to the numpy backend)
- Renamed the CI jobs so their names better match their actual purpose
- Updated the README regarding JAX backend support
- Fixed `setup.py` to support installation via `pip install torchstain[jax]`
- Fixed the `np.float32` deprecation in the numpy Macenko implementation
- Removed an unneeded numpy import from the TensorFlow Macenko backend
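For reviewers less familiar with the method, the stain-matrix estimation at the core of Macenko can be sketched in NumPy as below. This is a simplified sketch of the standard formulation, not the code in this PR. The function name `macenko_stain_matrix`, the array-module parameter `xp`, and the defaults `Io=240`, `alpha=1` and `beta=0.15` are assumptions for illustration.

```python
import numpy as np


def macenko_stain_matrix(I, xp=np, Io=240, alpha=1, beta=0.15):
    """Estimate a 3x2 H&E stain matrix from an RGB image (H x W x 3).

    Hypothetical sketch: `xp` is an assumed array-module parameter
    (e.g. numpy or jax.numpy), not part of torchstain's API.
    """
    I = xp.asarray(I).reshape((-1, 3))
    # Convert RGB intensities to optical density (OD).
    OD = -xp.log((I.astype(xp.float32) + 1) / Io)
    # Drop (nearly) transparent pixels: any channel with OD below beta.
    OD_hat = OD[~xp.any(OD < beta, axis=1)]
    # eigh returns eigenvalues in ascending order, so the last two
    # eigenvectors span the plane of largest OD variance.
    _, eigvecs = xp.linalg.eigh(xp.cov(OD_hat.T))
    plane = eigvecs[:, 1:3]
    # Project OD onto that plane and take robust extreme angles.
    That = OD_hat @ plane
    phi = xp.arctan2(That[:, 1], That[:, 0])
    min_phi = xp.percentile(phi, alpha)
    max_phi = xp.percentile(phi, 100 - alpha)
    v_min = plane @ xp.array([xp.cos(min_phi), xp.sin(min_phi)])
    v_max = plane @ xp.array([xp.cos(max_phi), xp.sin(max_phi)])
    # Heuristic ordering: the vector with the larger first (red) OD
    # component goes in column 0 (hematoxylin), the other in column 1.
    if v_min[0] > v_max[0]:
        return xp.stack([v_min, v_max], axis=1)
    return xp.stack([v_max, v_min], axis=1)
```

Since it only calls functions that `jax.numpy` also provides, passing `xp=jax.numpy` should run eagerly as well. A backend consistency check could then compute the matrix with both modules on the same image and compare the results with `np.allclose` at a loose tolerance, because JAX computes in float32 by default. The boolean-mask indexing and the Python `if` on array values make this form incompatible with `jax.jit` as written.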
Note that the JAX backend is not yet as optimized for runtime as the other backends, so for now I would consider this **experimental JAX support**. Runtime of the JAX backend compared to the other backends:
| Backend | numpy | jax | torch | tf |
|---|---|---|---|---|
| Runtime [s] | 0.455 | 2.427 | 0.201 | 0.442 |
The JAX implementation should be optimized further in future work, but this is outside my area of expertise, so contributions from more experienced JAX developers would be very welcome.
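When optimizing the JAX backend, keep in mind that JAX dispatches work asynchronously and compiles `jax.jit`-decorated functions on their first call. A plain timer around a single call can therefore include compilation time or stop before the computation has finished. Below is a minimal, stdlib-only sketch of a timing helper with warm-up runs and an explicit synchronization hook. `benchmark` and its parameters are hypothetical names and are not part of torchstain.

```python
import statistics
import time
from typing import Any, Callable


def benchmark(
    fn: Callable[[], Any],
    sync: Callable[[Any], Any] = lambda out: out,  # no-op for synchronous backends
    warmup: int = 2,
    repeats: int = 10,
) -> float:
    """Return the median wall-clock runtime of fn() in seconds (hypothetical helper)."""
    # Warm-up calls absorb one-off costs such as JIT compilation and caching.
    for _ in range(warmup):
        sync(fn())
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        # sync() must block until any asynchronously dispatched work is done.
        sync(fn())
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For JAX, one would pass `sync=jax.block_until_ready`. For numpy the default no-op `sync` is enough. Reporting the median rather than the mean keeps a single slow run from dominating the result.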