envfluids / py2d

Python 2D Navier-Stokes solver
MIT License

JAX and Numpy functions: How to structure them? #58

Open jakharkaran opened 5 months ago

jakharkaran commented 5 months ago

This repository contains a 2D Navier-Stokes equation solver and data processing methods. The solver, written using the JAX library, is computationally expensive and leverages GPU acceleration. The less intensive post-processing methods use NumPy.

Some functions are required by both the solver and post-processing. Currently, duplicate copies exist – one for JAX and one for NumPy. What is the best way to optimize this code structure?
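One way to remove the duplicate copies is to write each shared function once and pass the array module in as an argument, conventionally called `xp`. This sketch assumes the function only uses operations that exist in both `numpy` and `jax.numpy`. `spectral_ddx` is a hypothetical example (a Fourier-space x-derivative on a periodic grid), not code from py2d:

```python
import numpy as np


def spectral_ddx(field, length, xp=np):
    """Hypothetical shared helper: d(field)/dx on a periodic square grid.

    The field is indexed as field[y, x]. `xp` is the array module (numpy or
    jax.numpy). Both provide the fft and pi names used here, so a single
    definition serves the solver and the post-processing.
    """
    n = field.shape[-1]
    # Angular wavenumbers along x. rfft2 keeps only the non-negative half of
    # the last axis, so rfftfreq matches the shape of its output.
    kx = 2 * xp.pi * xp.fft.rfftfreq(n, d=length / n)
    field_hat = xp.fft.rfft2(field)
    # kx broadcasts across the y axis (axis 0) of field_hat.
    return xp.fft.irfft2(1j * kx * field_hat, s=field.shape)
```

Under JAX the same definition would be called as `spectral_ddx(jnp.asarray(field), 2 * jnp.pi, xp=jnp)`. If it is wrapped in `jax.jit`, `xp` (a module) and `length` would need to be marked static, for example with `static_argnames=("length", "xp")`. The main constraint of this pattern is that JAX arrays are immutable. Shared functions therefore cannot use in-place updates such as `a[i] = b` and must follow the functional style JAX requires.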

rmojgani commented 5 months ago

Also, at low resolutions NumPy is faster than JAX on GPU, so the ability to run on either backend is beneficial. Falling back to NumPy when JAX is not installed would also be favorable.
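A NumPy fallback for when JAX is not installed could look like the following minimal sketch. `get_array_module` and `maybe_jit` are hypothetical helper names. The idea is to treat JAX as an optional import and to apply `jax.jit` only on the JAX path:

```python
import importlib

import numpy as np


def get_array_module(prefer_jax=True):
    """Hypothetical helper: return jax.numpy if requested and importable, else numpy."""
    if prefer_jax:
        try:
            return importlib.import_module("jax.numpy")
        except ImportError:
            # JAX is an optional dependency: fall back to NumPy.
            pass
    return np


def maybe_jit(fn, xp):
    """Hypothetical helper: apply jax.jit only when running on the JAX backend."""
    if xp.__name__ == "jax.numpy":
        import jax
        return jax.jit(fn)
    return fn  # NumPy path: plain Python function, no compilation step
```

Because the fallback is silent, a real implementation might also log which backend was picked, so that a slow run is not mistaken for a JAX run.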

jakharkaran commented 5 months ago

For the solver, at low resolution JAX on CPU is faster than JAX on GPU. I haven't tested NumPy. I can mention this in the README for now.

rmojgani commented 5 months ago

I mean that, for this reason, it is good to have the backend as a selectable option, even for the solver runs.
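A user-selectable backend covering the three cases discussed above (NumPy, JAX on CPU, JAX on GPU) might be sketched as follows. The option names and the `--backend` flag are assumptions for illustration and are not part of py2d. Forcing the JAX platform relies on the `JAX_PLATFORMS` environment variable, which JAX reads when it is first imported:

```python
import argparse
import importlib
import os

import numpy as np

# Hypothetical option names; py2d does not define these.
BACKENDS = ("numpy", "jax-cpu", "jax-gpu")


def select_backend(name):
    """Return the array module for a user-selected backend."""
    if name not in BACKENDS:
        raise ValueError(f"unknown backend {name!r}; choose from {BACKENDS}")
    if name == "numpy":
        return np
    # Must be set before JAX is first imported in this process; if JAX is
    # already loaded, this may have no effect. "cuda" assumes an NVIDIA GPU.
    os.environ["JAX_PLATFORMS"] = "cpu" if name == "jax-cpu" else "cuda"
    return importlib.import_module("jax.numpy")


def parse_args(argv=None):
    """Parse a hypothetical --backend command-line option."""
    parser = argparse.ArgumentParser(description="run with a chosen array backend")
    parser.add_argument("--backend", choices=BACKENDS, default="numpy")
    return parser.parse_args(argv)
```

A run script would then call `xp = select_backend(parse_args().backend)` once and pass `xp` to every shared function. Switching backends for a low-resolution run would then require no code changes.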

jwallwork23 commented 1 month ago

A colleague just told me about the Python "Array API" concept, which might be useful reading for this: https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders
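With the Array API standard, the namespace can be recovered from the input array itself instead of being passed around as an argument. The following is a minimal sketch. It assumes arrays that implement `__array_namespace__()`, which NumPy 2.x `ndarray` and recent JAX arrays do, and falls back to NumPy otherwise. `kinetic_energy` is a hypothetical example function. The `array_api_compat` package offers a more complete `array_namespace` helper that also handles older library versions.

```python
import numpy as np


def array_namespace(x):
    """Return the Array API namespace of `x`, falling back to NumPy."""
    # Arrays that conform to the Array API standard expose __array_namespace__().
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np  # e.g. NumPy 1.x arrays or plain Python sequences


def kinetic_energy(u, v):
    """Hypothetical example: domain-averaged kinetic energy per unit mass."""
    xp = array_namespace(u)  # numpy for NumPy input, jax.numpy for JAX input
    return 0.5 * xp.mean(u**2 + v**2)
```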