Unidata / MetPy

MetPy is a collection of tools in Python for reading, visualizing and performing calculations with weather data.
https://unidata.github.io/MetPy/
BSD 3-Clause "New" or "Revised" License

Integration with Python performance tools like JAX #3432

Open ThomasMGeo opened 3 months ago

ThomasMGeo commented 3 months ago

What should we add?

There are a few ways to speed up raw numpy calculations. I looked at two options:

  1. JAX (see documentation here)
  2. Numba (see documentation here)

This was tested on an M2 MacBook Pro (so no NVIDIA GPU), but I didn't have any trouble installing either package. Overall, for two basic calculations, I saw speedups on the order of 3-10x. These results are just from a few hours of hacking and are not intended to be strict benchmarks.

My Take

JAX was much easier to use, and faster. It felt much more like a 'drop-in' numpy replacement than rewriting JIT-compiled Numba functions, which didn't support all of numpy's functionality. If I needed to write faster numpy code, doing so with JAX was straightforward, with or without a GPU.
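To illustrate the 'drop-in' point, here is a minimal sketch that assumes a function is written against a NumPy-compatible namespace passed in as a hypothetical `xp` parameter (this is not how MetPy's API is structured). The same Bolton (1980) saturation vapor pressure formula then runs on NumPy arrays and, if JAX happens to be installed, on `jax.numpy` arrays under `jax.jit`.

```python
import numpy as np


def saturation_vapor_pressure(temperature_k, xp=np):
    """Bolton (1980) saturation vapor pressure over water, in hPa.

    `xp` is a hypothetical parameter selecting the array library; any module
    with a NumPy-compatible `exp` (numpy, jax.numpy, cupy) should work.
    """
    t_c = temperature_k - 273.15
    return 6.112 * xp.exp(17.67 * t_c / (t_c + 243.5))


temps = np.array([273.15, 293.15])
es_np = saturation_vapor_pressure(temps)  # plain NumPy

try:
    import jax
    import jax.numpy as jnp

    # Same function body, now traced and compiled by JAX.
    es_jax_fn = jax.jit(lambda t: saturation_vapor_pressure(t, xp=jnp))
    es_jax = es_jax_fn(jnp.asarray(temps))
except ImportError:  # JAX is optional for this sketch
    es_jax = None

print(es_np)  # roughly [6.112, 23.37] hPa
```

Note that JAX defaults to 32-bit floats, so results from the two paths agree only to single precision.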

Future packages or workflows to consider:

  1. CuPy (downside: this requires a CUDA-enabled GPU)
  2. Cython might be another option
  3. Multiprocessing?

Notebook

A simple test notebook is here.

winash12 commented 3 months ago

How about adding an xarray + Dask example? What would a Jupyter notebook that includes parallel processing look like?

jthielen commented 3 months ago

Just to add this for the sake of reference: about a year and a half ago I did some experiments with a Numba-based re-implementation of MetPy's CAPE calculations, as shown here: https://github.com/jthielen/cumulonumba/blob/main/examples/cumulonumba_v_metpy_rough_test.ipynb. Key takeaways were that the speed up with Numba was substantial (by two or three orders of magnitude), but that the JIT compilation costs were not insignificant (and perhaps a deal breaker for some use cases). This was yet another factor favoring Cython over Numba for MetPy's purposes.
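Because the JIT compilation cost is the sticking point here, it helps to time the first call separately from later ones. This is a minimal sketch that assumes Numba may or may not be installed. The fallback decorator just returns the plain Python function, so in that case the timings show no compile cost. `sum_of_squares` is a hypothetical toy kernel, not code from cumulonumba or MetPy.

```python
import statistics
import time

import numpy as np

try:
    from numba import njit
except ImportError:  # without Numba, run the kernel as plain Python
    def njit(func):
        return func


@njit
def sum_of_squares(values):
    # Explicit loop: the kind of code Numba compiles to machine code.
    total = 0.0
    for v in values:
        total += v * v
    return total


def first_call_vs_steady_state(func, *args, repeats=5):
    """Return (seconds for the first call, median seconds of later calls).

    With a lazily compiled JIT function, the first call also pays for
    compilation, so the gap between the two numbers approximates that cost.
    """
    start = time.perf_counter()
    func(*args)
    first = time.perf_counter() - start
    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        later.append(time.perf_counter() - start)
    return first, statistics.median(later)


first, steady = first_call_vs_steady_state(sum_of_squares, np.arange(10_000.0))
print(f"first call: {first:.4f} s, later calls: {steady:.6f} s")
```

Whether that one-time cost matters depends on how often a fresh process has to compile the functions, which is why it can be a deal breaker for short-lived scripts but negligible for long batch jobs.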

ThomasMGeo commented 3 months ago

Thanks for the add, @jthielen! Have you had the chance to mess around with JAX? I know you're quite busy :)

But overall, I agree that Numba is not the solution.

ThomasMGeo commented 3 months ago

@winash12, do you have a specific problem in mind to solve with xarray/Dask?

jthielen commented 3 months ago

@ThomasMGeo Only a little bit, and not in this context, unfortunately! That being said, for some of the underlying array operations (intersection finding, fixed-point iteration), my hunch is that a JAX-type approach (given its more functional way of doing things) would require more refactoring than Numba would. I could be mistaken on that too, though!
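To make the 'more functional' point concrete: `jax.lax.while_loop(cond_fun, body_fun, init_val)` expresses a loop as a condition function and a body function over an explicit state, rather than as a `while` loop that mutates local variables. Below is a minimal sketch of a fixed-point iteration written in that shape. It uses a lifting condensation level (LCL) pressure iteration purely as an illustration and assumes the Bolton (1980) vapor pressure formula, approximate dry-air constants, and a pure-Python `while_loop` stand-in with the semantics the JAX docs describe, so it runs without JAX. It is not MetPy's implementation.

```python
import math

RD = 287.04  # J kg^-1 K^-1, dry-air gas constant (approximate)
CP = 1005.7  # J kg^-1 K^-1, dry-air specific heat (approximate)
EPS = 0.622  # ratio of gas constants, dry air / water vapor
KAPPA = RD / CP


def saturation_vapor_pressure(t):
    """Bolton (1980) over water; t in K, result in Pa."""
    return 611.2 * math.exp(17.67 * (t - 273.15) / (t - 29.65))


def dewpoint_from_vapor_pressure(e):
    """Inverse of the Bolton formula; e in Pa, result in K."""
    x = math.log(e / 611.2)
    return 243.5 * x / (17.67 - x) + 273.15


def while_loop(cond_fun, body_fun, init_val):
    # Pure-Python stand-in with the same semantics as jax.lax.while_loop.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def lcl_pressure(p0, t0, td0, tol=0.01):
    """Fixed-point iteration for the LCL pressure (Pa) of a parcel at p0, t0, td0."""
    e0 = saturation_vapor_pressure(td0)
    w = EPS * e0 / (p0 - e0)  # mixing ratio, conserved during dry ascent

    def body(state):
        p, _ = state
        td = dewpoint_from_vapor_pressure(p * w / (EPS + w))
        # Pressure at which the dry adiabat from (p0, t0) reaches td.
        p_new = p0 * (td / t0) ** (1.0 / KAPPA)
        return p_new, abs(p_new - p)

    def cond(state):
        return state[1] > tol  # keep iterating until the update is tiny

    p_lcl, _ = while_loop(cond, body, (p0, math.inf))
    return p_lcl


print(lcl_pressure(100000.0, 293.15, 283.15))
```

A JAX port would presumably also swap `math` for `jax.numpy` and the stand-in for `jax.lax.while_loop`. Restructuring like this, rather than just adding a decorator, is the sort of refactoring in question.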

dopplershift commented 3 months ago

@winash12 Interoperability with Dask is one of the major technical areas we are focusing on at the moment.

winash12 commented 3 months ago

Regarding the Cython usage, how do you propose to take it forward? Will the Cython code be written in Python and converted to C code by the compiler, or can we add C or C++ functions? The second option would need C makefiles for the build to go through, plus modifications to the LDPATH, etc. From an implementation perspective, my question is: are you planning to permit cdef functions, or purely def functions? Looking at SciPy's implementation, many of its classes do use cdef functions.

If we are planning to use cdef functions, then it's worthwhile to look at xtensor: https://xtensor.readthedocs.io/en/latest/

@ThomasMGeo As an example, let's assume I want to calculate potential vorticity (PV) for 4 different times, and the input data is in a single netCDF file. Can I calculate the PV for the four time instances in parallel? Most definitely, since they are mutually independent data snapshots. For that I need to use Dask arrays, if I am not mistaken. The last time I attended the conference call, I recall everyone agreeing that there isn't a notebook yet that does this.
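The per-time independence described here is what makes the problem embarrassingly parallel. Here is a minimal sketch of the idea without Dask. It assumes a hypothetical `relative_vorticity` as a stand-in for the real PV calculation (it is not MetPy's function) and a `(time, y, x)` array layout. With xarray, the Dask route would typically start from something like `ds.chunk({"time": 1})`, so that each time step becomes its own chunk.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def relative_vorticity(u, v, dx, dy):
    """Hypothetical stand-in for a per-time PV calculation: dv/dx - du/dy.

    Arrays are (..., y, x); dx and dy are grid spacings.
    """
    dvdx = np.gradient(v, dx, axis=-1)
    dudy = np.gradient(u, dy, axis=-2)
    return dvdx - dudy


def per_time_parallel(u, v, dx, dy, max_workers=4):
    """Apply the calculation to each time slice (axis 0) concurrently."""

    def one_time(i):
        return relative_vorticity(u[i], v[i], dx, dy)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with times.
        return np.stack(list(pool.map(one_time, range(u.shape[0]))))


# Usage: solid-body rotation (u = -y, v = x) on a 5x6 grid, repeated for 4 times.
y, x = np.meshgrid(np.arange(5.0), np.arange(6.0), indexing="ij")
u = np.broadcast_to(-y, (4, 5, 6))
v = np.broadcast_to(x, (4, 5, 6))
vort = per_time_parallel(u, v, dx=1.0, dy=1.0)
```

Threads are used here for simplicity. Whether they actually speed things up depends on how much of the per-slice work releases the GIL (many NumPy operations do), and for larger jobs a Dask scheduler or a process pool may be the better fit.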

winash12 commented 3 months ago

Actually, looking at it again, if all we want is a faster version of numpy, then I question the need for Cython. xtensor has a Python wrapper that we can use: https://github.com/xtensor-stack/xtensor-python

https://xtensor.readthedocs.io/en/latest/numpy.html

dopplershift commented 3 months ago

@winash12 I really need to update our roadmap with this stuff (#1655), but the plan is to only update particular places in the code that are bottlenecks to doing calculations at scale, and to find what's slow through benchmarks. The top offender that comes to mind is CAPE/CIN, mostly due to moist_lapse(). CAPE/CIN is especially problematic because the nature of the calculation resists vectorization, so direct looping is the only option. Hence, we're looking at compiled solutions.
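The moist-adiabat part resists vectorization because the temperature at each pressure level depends on the temperature at the previous one, so the integration has to step sequentially along the profile. Below is a minimal sketch that assumes a common approximate form of the saturated lapse rate in pressure coordinates, the Bolton (1980) vapor pressure formula, approximate constants, and a hand-rolled fourth-order Runge-Kutta step. It is not MetPy's `moist_lapse()` implementation. It only shows the shape of the loop that a compiled version (Cython or otherwise) would speed up.

```python
import math

RD = 287.04   # J kg^-1 K^-1, dry-air gas constant (approximate)
CP = 1005.7   # J kg^-1 K^-1, dry-air specific heat (approximate)
LV = 2.501e6  # J kg^-1, latent heat of vaporization (approximate)
EPS = 0.622   # ratio of gas constants, dry air / water vapor


def saturation_mixing_ratio(p, t):
    """Saturation mixing ratio (kg/kg) via Bolton (1980); p in Pa, t in K."""
    es = 611.2 * math.exp(17.67 * (t - 273.15) / (t - 29.65))
    return EPS * es / (p - es)


def dt_dp(p, t):
    """Saturated lapse rate in pressure coordinates, K/Pa (approximate form)."""
    rs = saturation_mixing_ratio(p, t)
    numerator = RD * t + LV * rs
    denominator = CP + LV * LV * rs * EPS / (RD * t * t)
    return numerator / denominator / p


def moist_lapse_sketch(pressures, t0, substeps=10):
    """Integrate T from t0 (K) through the given pressure levels (Pa) with RK4."""
    temps = [t0]
    t = t0
    for p_start, p_end in zip(pressures[:-1], pressures[1:]):
        h = (p_end - p_start) / substeps
        p = p_start
        for _ in range(substeps):
            # Each step needs the current t: this is the serial dependency.
            k1 = dt_dp(p, t)
            k2 = dt_dp(p + h / 2, t + h * k1 / 2)
            k3 = dt_dp(p + h / 2, t + h * k2 / 2)
            k4 = dt_dp(p + h, t + h * k3)
            t += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
            p += h
        temps.append(t)
    return temps


levels = [100000.0, 90000.0, 80000.0, 70000.0, 60000.0, 50000.0]
profile = moist_lapse_sketch(levels, 293.15)
```

Both the loop over levels and the inner RK4 loop are inherently serial. Only independent profiles, such as different grid columns, could be processed side by side.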

That does not imply we're looking at general solutions for a faster numpy. It is really important for ease of maintenance and contributions from the community that we stick to Python. The nature of @ThomasMGeo's investigation was really to look at how well people using tools like JAX or CuPy can pass data from those libraries (which are numpy-like) into MetPy and have things "just work". We have no plans to depend on them, however. The same can be said about our plans for supporting Dask--we want to make sure we facilitate workflows using Dask (like the one you described for multiple levels of PV analysis), but we will not be using Dask directly within MetPy.

Currently on the table are:

The leader is Cython, due to how commonplace it is within the scientific Python ecosystem. Also, I am heavily interested in the ability to run Python (with MetPy) within web browsers, so any solution chosen needs to be amenable to WebAssembly (WASM), which likely rules out Numba. Rust and C++ are included for completeness (Rust mainly because there's a lot of momentum there, though I'm unclear on the numpy integration story), but I'm 95% sure we're going down the Cython route.

leaver2000 commented 2 months ago

@Z-Richard mentioned I should drop my code into a public repo.

The code has a single runtime dependency on numpy and requires Cython to compile the moist_lapse ODE. There is a notebook that pulls data from the weatherbench2 Zarr store. The code needs to be compiled in a certain way to get code coverage on the compiled binary, which slows things down quite a bit.

This is dcape:

[image: dcape]