While Numba enables JIT-compiled and parallelized Python code on the CPU, its GPU support is lacking in features and regressing (see the deprecation of AMD support). One aspect of this problem is the batching of data, see #24.
JAX is a solid alternative to Numba that ships with similar features and extends them further. Some notable differences are:
- Code can be executed on CPU, GPU, and TPU (Tensor Processing Unit, usually used on Google Cloud) by switching an environment variable
- GPU support currently targets NVIDIA hardware, but an experimental AMD integration is available and works to a certain degree
- Vector mapping (`jax.vmap`) enables easy parallelization of the code, and possibly batching using `jax.lax.map`
- Using an environment variable, one can toggle between single and double precision; single precision is usually sufficient and speeds things up
- Compared to Numba, JAX has a whole "ecosystem" of modules built on top of it, making scientific computing with JAX much easier
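As a rough illustration of the points above, here is a minimal sketch, not the package's actual code. The kernel `pairwise_energy` and the arrays `xs` and `ys` are hypothetical placeholders. It assumes the `JAX_PLATFORMS` and `JAX_ENABLE_X64` environment variables are the switches meant above, and that they are set before JAX is first imported.

```python
import os

# Assumption: both variables must be set before JAX is imported.
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # e.g. "cpu", "cuda", "tpu"
os.environ.setdefault("JAX_ENABLE_X64", "0")   # "1" enables double precision

import jax
import jax.numpy as jnp


def pairwise_energy(x, y):
    """Hypothetical per-sample kernel: squared distance between two vectors."""
    return jnp.sum((x - y) ** 2)


# jax.vmap vectorizes the kernel over the leading axis of both inputs,
# and jax.jit compiles the vectorized function for the selected backend.
batched_energy = jax.jit(jax.vmap(pairwise_energy))

xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones((4, 3))

vmap_result = batched_energy(xs, ys)

# jax.lax.map applies the kernel to one slice at a time along the leading
# axis, trading speed for lower peak memory on large inputs.
lax_map_result = jax.lax.map(lambda pair: pairwise_energy(*pair), (xs, ys))
```

Both calls compute the same values; `jax.vmap` processes all rows at once, while `jax.lax.map` loops over them, which may help when the full batch does not fit into device memory.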
A package rework is planned and will be developed on a separate branch until all tests pass.