esa / torchquad

Numerical integration in arbitrary dimensions on the GPU using PyTorch / TF / JAX
https://www.esa.int/gsp/ACT/open_source/torchquad/
GNU General Public License v3.0

[jax] Gaussian and GaussLegendre throw errors #214

Open javier-garcia-tilburg opened 3 days ago


Issue

Problem Description

The Gaussian and GaussLegendre integrators do not work with the jax backend. Both raise:

TypeError: prod requires ndarray or scalar arguments, got <class 'list'> at position 0.

The error originates in these lines:

https://github.com/esa/torchquad/blob/bf8ed5c1687e52d223cb8f55c0a99b2c8bbeaded/torchquad/integration/gaussian.py#L67-L69
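
The failure can probably be isolated without torchquad. The sketch below assumes that the linked lines pass the list returned by meshgrid straight to prod, and that with the jax backend these calls end up in jax.numpy. The weights array is a made-up example, not real quadrature weights.

```python
import jax.numpy as jnp

# Hypothetical 1D weights, chosen only for illustration.
weights = jnp.asarray([0.5, 1.0, 0.5])

# jnp.meshgrid returns a Python list of arrays, not a single array.
grids = jnp.meshgrid(weights, weights)

# Passing the list directly is what is assumed to trigger the reported error.
try:
    jnp.prod(grids, axis=0)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Stacking the list into one array first avoids the list argument.
stacked = jnp.prod(jnp.stack(grids), axis=0).ravel()
print(raised, stacked.shape)
```

If this assumption holds, the first call fails with the same message as above, while the stacked variant returns the flattened product of the per-dimension weights.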

What Needs to be Done

I propose converting the list returned by anp.meshgrid into an array with anp.stack before passing it to anp.prod. This fix works for me with jax; however, we should first check that it doesn't cause problems with the other backends.

            return anp.prod(
                anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0
            ).ravel()
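
As a first step towards checking other backends, here is a minimal sketch for NumPy only. It assumes that the current code calls prod with axis=0 on the meshgrid list and that anp.meshgrid and anp.stack dispatch to the NumPy functions of the same name. The weights are made-up values. torch and tensorflow would still need their own checks.

```python
import numpy as np

# Hypothetical 1D weights, chosen only for illustration.
weights = np.array([0.2, 0.5, 0.3])
dim = 3

# One grid per dimension, returned as a sequence of arrays.
grids = np.meshgrid(*([weights] * dim))

# Assumed current behaviour: prod applied to the sequence directly.
current = np.prod(grids, axis=0).ravel()

# Proposed behaviour: stack into a single array first.
proposed = np.prod(np.stack(grids), axis=0).ravel()

same = np.allclose(current, proposed)
print(same, proposed.shape)
```

In NumPy the two variants should agree, since np.prod converts a list of equally shaped arrays into one array before reducing.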

How Can It Be Tested or Reproduced

With jax 0.4.35 and torchquad 0.4.0, run:

import jax
import jax.numpy as jnp
from torchquad import set_up_backend, MonteCarlo, Gaussian, GaussLegendre

set_up_backend(backend="jax")

@jax.jit
def some_function(x):
    return jnp.power(x[:, 0] - x[:, 1], 2)

g = Gaussian()
# It also fails with GaussLegendre
# g = GaussLegendre()

integral_value = g.integrate(
    lambda x: some_function(x),
    dim=2,
    N=10000,
    integration_domain=jnp.asarray([[-1.0, 1.0], [-1.0, 1.0]]),
)