f0uriest / quadax

Numerical quadrature with JAX
MIT License

N-D Quadrature #10

Status: Open · JGameCreation opened this issue 1 month ago

JGameCreation commented 1 month ago

Hi, big fan of this repository! Are there any updates on n-dimensional quadrature? It would be very useful for me!

f0uriest commented 1 month ago

I've been giving this some thought lately. Initially I was planning to do something like nquad in SciPy, which is basically just a wrapper around recursive calls to quad. However, doing this in JAX ended up causing lots of issues with jit and AD because of the large number of local function definitions, and it is also likely very inefficient on GPU since it's almost entirely sequential. There is probably still some way to make it work, but I haven't had time to experiment with it further, so I would welcome contributions for that.
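To make the nquad-style structure concrete, here is a minimal plain-JAX sketch of iterated 1-D quadrature. It assumes a fixed 20-point Gauss-Legendre rule in place of an adaptive quad (so there is no error control) and a rectangular domain. `quad_1d` and `nquad_iterated` are hypothetical names and not part of quadax. Each extra dimension adds one more nested closure, which is where the growing number of local function definitions comes from.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Fixed-order Gauss-Legendre rule on [-1, 1]; stands in for an adaptive quad here.
_x, _w = np.polynomial.legendre.leggauss(20)
GL_NODES, GL_WEIGHTS = jnp.asarray(_x), jnp.asarray(_w)

def quad_1d(f, a, b):
    """Integrate a scalar function f over [a, b] with the fixed 20-point rule."""
    mid, half = 0.5 * (a + b), 0.5 * (b - a)
    return half * jnp.dot(GL_WEIGHTS, jax.vmap(f)(mid + half * GL_NODES))

def nquad_iterated(f, ranges):
    """Integrate f(x0, x1, ...) over a box by nesting 1-D rules, nquad-style.

    ranges[k] = (a_k, b_k) gives the limits for the k-th argument of f.
    """
    (a, b), rest = ranges[0], ranges[1:]
    if not rest:
        return quad_1d(f, a, b)

    # One new closure per level: the integral over the remaining variables,
    # viewed as a function of the current variable x0.
    def inner(x0):
        return nquad_iterated(lambda *xs: f(x0, *xs), rest)

    return quad_1d(inner, a, b)
```

For example, `nquad_iterated(lambda x, y: x * y, [(0.0, 1.0), (0.0, 1.0)])` should give approximately 0.25. Because the rule here is fixed, every level vmaps into a single batched evaluation of 20**d points; with an adaptive quad at each level, the inner integrals would instead run data-dependent loops, which is what makes the recursive approach hard to batch.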

The other main approach is "proper" N-D quadrature using genuine N-D rules rather than iterated 1-D rules. I'm still reading up on the theory of this (it's mostly the same as the 1-D case, but there's the additional issue of deciding which axis to split cells along, and many of the rules/algorithms are specific to a particular number of dimensions).
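A rough sketch of the axis-splitting question, under loudly stated assumptions: each cell uses a tensor-product Gauss-Legendre rule (not a fully symmetric N-D rule such as Genz-Malik), and the per-axis error indicator is a simple heuristic that lowers the rule order along one axis at a time and measures how much the estimate changes. The worst cell is then bisected along its worst axis. All names (`tensor_rule`, `axis_errors`, `adaptive_cubature`) are hypothetical; this is a host-side Python loop, not a jit-compatible or quadax implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp

def gauss_legendre(n):
    # Nodes and weights of the n-point Gauss-Legendre rule on [-1, 1]
    x, w = np.polynomial.legendre.leggauss(n)
    return jnp.asarray(x), jnp.asarray(w)

def tensor_rule(f, lo, hi, orders):
    """Tensor-product Gauss-Legendre estimate of the integral of f over [lo, hi].

    f maps a point of shape (d,) to a scalar; orders[i] is the rule order on axis i.
    """
    rules = [gauss_legendre(n) for n in orders]
    xs = jnp.meshgrid(*[x for x, _ in rules], indexing="ij")
    ws = jnp.meshgrid(*[w for _, w in rules], indexing="ij")
    pts = jnp.stack([x.ravel() for x in xs], axis=-1)                     # (N, d) in [-1, 1]^d
    wts = jnp.prod(jnp.stack([w.ravel() for w in ws], axis=-1), axis=-1)  # (N,)
    mid, half = 0.5 * (lo + hi), 0.5 * (hi - lo)
    vals = jax.vmap(f)(mid + half * pts)
    return jnp.prod(half) * jnp.dot(wts, vals)

def axis_errors(f, lo, hi, p=6, q=4):
    """Integral estimate plus a heuristic per-axis error indicator.

    Axis i's indicator is the change in the estimate when only axis i
    drops from order p to order q.
    """
    d = lo.shape[0]
    full = tensor_rule(f, lo, hi, [p] * d)
    errs = []
    for i in range(d):
        orders = [p] * d
        orders[i] = q
        errs.append(jnp.abs(full - tensor_rule(f, lo, hi, orders)))
    return full, jnp.stack(errs)

def adaptive_cubature(f, lo, hi, tol=1e-6, max_iter=200):
    """Greedy adaptive cubature: bisect the worst cell along its worst axis."""
    def make_cell(clo, chi):
        est, errs = axis_errors(f, clo, chi)
        # (error estimate, integral estimate, axis to split, lower corner, upper corner)
        return float(errs.sum()), float(est), int(jnp.argmax(errs)), clo, chi

    lo = jnp.asarray(lo, dtype=jnp.float32)
    hi = jnp.asarray(hi, dtype=jnp.float32)
    cells = [make_cell(lo, hi)]
    for _ in range(max_iter):
        if sum(c[0] for c in cells) < tol:
            break
        worst = max(range(len(cells)), key=lambda j: cells[j][0])
        _, _, axis, clo, chi = cells.pop(worst)
        m = 0.5 * (clo[axis] + chi[axis])
        cells.append(make_cell(clo, chi.at[axis].set(m)))  # lower half along `axis`
        cells.append(make_cell(clo.at[axis].set(m), chi))  # upper half along `axis`
    return sum(c[1] for c in cells), sum(c[0] for c in cells)
```

For a function like cos(20 * x[1]) that varies only along the second axis, the indicator for axis 0 stays near rounding level, so the cells end up being split only along axis 1. The choice of split axis is exactly the piece that has no 1-D analogue.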

JGameCreation commented 1 month ago

Thank you for sharing your thoughts! I guess the main problem with the recursive approach is that you would vmap over conditional control flow, which means you always have to wait for the slowest computation to finish. Therefore, I think some kind of batched N-D quadrature rule would be ideal. Maybe it would be sufficient to start with 2 and 3 dimensions?
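Following up on the batched idea, here is a minimal sketch of what a fixed 2-D rule applied to a batch of cells could look like in plain JAX. Assumptions: a 5x5 tensor-product Gauss-Legendre rule on every cell and a uniform n x n partition of the box (no adaptivity). `cell_integral`, `uniform_cells` and `integrate_2d` are hypothetical names, not quadax API. Because every cell does identical work, vmap sees no data-dependent branches or loops, so no lane waits on a slower one.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Fixed 5x5 tensor-product Gauss-Legendre rule on [-1, 1]^2
_x, _w = np.polynomial.legendre.leggauss(5)
_X, _Y = np.meshgrid(_x, _x, indexing="ij")
NODES_2D = jnp.asarray(np.stack([_X.ravel(), _Y.ravel()], axis=-1))  # (25, 2)
WEIGHTS_2D = jnp.asarray(np.outer(_w, _w).ravel())                    # (25,)

def cell_integral(f, lo, hi):
    """Apply the fixed rule to one cell [lo, hi]; lo and hi have shape (2,)."""
    mid, half = 0.5 * (lo + hi), 0.5 * (hi - lo)
    vals = jax.vmap(f)(mid + half * NODES_2D)
    return jnp.prod(half) * jnp.dot(WEIGHTS_2D, vals)

def uniform_cells(lo, hi, n):
    """Split the box [lo, hi] into an n x n grid; returns lower and upper corners."""
    ex = jnp.linspace(lo[0], hi[0], n + 1)
    ey = jnp.linspace(lo[1], hi[1], n + 1)
    lx, ly = jnp.meshgrid(ex[:-1], ey[:-1], indexing="ij")
    ux, uy = jnp.meshgrid(ex[1:], ey[1:], indexing="ij")
    los = jnp.stack([lx.ravel(), ly.ravel()], axis=-1)  # (n*n, 2)
    his = jnp.stack([ux.ravel(), uy.ravel()], axis=-1)  # (n*n, 2)
    return los, his

def integrate_2d(f, lo, hi, n=8):
    los, his = uniform_cells(jnp.asarray(lo), jnp.asarray(hi), n)
    # One batched call over all cells: same fixed work per cell, no control flow.
    per_cell = jax.vmap(lambda a, b: cell_integral(f, a, b))(los, his)
    return jnp.sum(per_cell)

# f and n must be static under jit (n sets the array shapes).
integrate_2d_jit = jax.jit(integrate_2d, static_argnums=(0, 3))
```

An adaptive version would still need some loop that decides which cells to refine, but the per-cell evaluation could stay a single batched call like this one; extending it to 3 dimensions mainly means a 3-D node/weight table and one more meshgrid axis.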