jakharkaran opened 5 months ago
Also, at low resolutions `numpy` is faster than `jax` on the GPU, so the ability to run on either is beneficial. Falling back to `numpy` when `jax` is not installed would also be nice.
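Here is a minimal sketch of the import fallback. It assumes the shared helpers are written against a module alias `xp`. The names `xp`, `BACKEND` and `kinetic_energy` are illustrative and not taken from the repository:

```python
# Prefer JAX when it is installed; otherwise fall back to NumPy.
try:
    import jax.numpy as xp
    BACKEND = "jax"
except ImportError:
    import numpy as xp
    BACKEND = "numpy"


def kinetic_energy(u, v):
    """Mean kinetic energy 0.5 * <u^2 + v^2>, written only against `xp`."""
    return 0.5 * xp.mean(u**2 + v**2)
```

Because `kinetic_energy` only calls `xp.mean`, the same source runs unchanged on either library.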
For the solver, JAX on the CPU is faster than JAX on the GPU at low resolution. I haven't tested NumPy. I can mention this in the README for now.
That's exactly why I think it is good to have the backend as a selectable option, even for the solver runs.
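One way to expose this as a per-run option is sketched below with `argparse`. The `--backend`/`--device` flags and `parse_run_options` are hypothetical names. Pinning JAX to the CPU uses the `JAX_PLATFORMS` environment variable, which only takes effect if it is set before JAX initialises its backends:

```python
import argparse
import os


def parse_run_options(argv=None):
    """Parse hypothetical solver flags selecting the array backend and device."""
    parser = argparse.ArgumentParser(description="2D Navier-Stokes run options (sketch)")
    parser.add_argument("--backend", choices=["numpy", "jax"], default="jax")
    parser.add_argument("--device", choices=["cpu", "gpu"], default="gpu",
                        help="only used with --backend jax; 'gpu' keeps JAX's default platform choice")
    args = parser.parse_args(argv)
    if args.backend == "jax" and args.device == "cpu":
        # JAX reads JAX_PLATFORMS when it initialises its backends, so this must
        # run before the first JAX computation (ideally before importing jax).
        os.environ["JAX_PLATFORMS"] = "cpu"
    return args
```

A low-resolution run could then be launched with something like `--backend jax --device cpu`, and the flags could also be recorded alongside the output so benchmark comparisons stay reproducible.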
A colleague just told me about the Python "Array API" concept, which might be useful reading for this: https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders
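With the Array API approach, a shared function asks its input for its namespace instead of hard-coding `np` or `jnp`. Below is a rough sketch. It assumes NumPy >= 2.0 or a recent JAX, whose arrays provide `__array_namespace__()`, and it falls back to NumPy for older arrays. `get_namespace` and `spectral_ddx` are hypothetical helpers, and the spectral x-derivative is just a stand-in for a function shared by the solver and the post-processing. The `array-api-compat` package offers an `array_namespace()` helper that handles this more robustly:

```python
import numpy as np


def get_namespace(x):
    """Return the array-API namespace of `x`, falling back to NumPy.

    Arrays implementing the Array API standard expose `__array_namespace__()`;
    older NumPy arrays do not, so plain NumPy is used for them.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def spectral_ddx(f, length=2 * np.pi):
    """d f / d x of a periodic 2D field f[y, x], computed with FFTs."""
    xp = get_namespace(f)
    ny, nx = f.shape
    # Angular wavenumbers 2*pi*m/length along the x axis (axis 1).
    kx = 2 * np.pi * xp.fft.fftfreq(nx, d=length / nx)
    f_hat = xp.fft.fft2(f)
    return xp.real(xp.fft.ifft2(1j * kx[None, :] * f_hat))
```

With older JAX versions, the fallback would silently route JAX arrays through NumPy and break `jax.jit`, so it is worth checking the installed versions.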
This repository contains a 2D Navier-Stokes equation solver and data processing methods. The solver, written using the JAX library, is computationally expensive and leverages GPU acceleration. The less intensive post-processing methods use NumPy.
Some functions are required by both the solver and post-processing. Currently, duplicate copies exist – one for JAX and one for NumPy. What is the best way to optimize this code structure?
`backend = 'numpy'` or `'jax'`: suitable for functions where the underlying structure is identical between NumPy and JAX, with only the library calls differing.
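Here is a minimal sketch of the `backend` keyword option. JAX is imported lazily, so NumPy-only users don't need it installed. `_array_module` and `wavenumbers` are illustrative names, and the wavenumber grid is just an example of a function whose structure is identical in both libraries:

```python
import numpy as np


def _array_module(backend):
    """Map the `backend` string to an array module, importing JAX only if asked."""
    if backend == "numpy":
        return np
    if backend == "jax":
        import jax.numpy as jnp  # only needed when JAX is requested
        return jnp
    raise ValueError(f"unknown backend {backend!r}; expected 'numpy' or 'jax'")


def wavenumbers(nx, ny, length=2 * np.pi, backend="numpy"):
    """Return 2D angular wavenumber grids (KX, KY) on the chosen backend."""
    xp = _array_module(backend)
    kx = 2 * np.pi * xp.fft.fftfreq(nx, d=length / nx)
    ky = 2 * np.pi * xp.fft.fftfreq(ny, d=length / ny)
    KX, KY = xp.meshgrid(kx, ky)  # KX varies along axis 1, KY along axis 0
    return KX, KY
```

If such a function is called inside a `jax.jit`-compiled solver step, `backend` would have to be marked as a static argument, since it is a string that selects Python code paths.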