javier-garcia-tilburg opened 3 days ago
Problem Description

The `Gaussian` and `GaussLegendre` integrators don't work with the JAX backend. They throw

```
TypeError: prod requires ndarray or scalar arguments, got <class 'list'> at position 0.
```

This error comes from these lines in `torchquad/integration/gaussian.py`:
https://github.com/esa/torchquad/blob/bf8ed5c1687e52d223cb8f55c0a99b2c8bbeaded/torchquad/integration/gaussian.py#L67-L69
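The failure can be reproduced with `jax.numpy` alone. In the sketch below, `jnp.meshgrid` hands back a Python sequence of arrays, and `jnp.prod`, unlike `np.prod`, rejects non-array arguments instead of converting them. The 1D `weights` array is a made-up stand-in, not the weights torchquad actually computes.

```python
import jax.numpy as jnp

# Hypothetical 1D quadrature weights; any small array triggers the same error.
weights = jnp.array([0.5, 1.0, 0.5])
dim = 2

# jnp.meshgrid returns a Python sequence of arrays, not a single array.
grids = jnp.meshgrid(*([weights] * dim))

# jnp.prod refuses the sequence, whereas np.prod would convert it first.
raised = False
try:
    jnp.prod(grids, axis=0)
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Stacking first gives jnp.prod one (dim, n, n) array, which it accepts.
product_weights = jnp.prod(jnp.stack(grids), axis=0).ravel()
print(product_weights)
```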
What Needs to be Done

I propose converting the list returned by `anp.meshgrid` into an array with `anp.stack` before passing it to `anp.prod`. This fix works for me with JAX, but we should first check that it doesn't cause problems with the other backends.

```python
return anp.prod(
    anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0
).ravel()
```
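As a first check for the other backends, here is a NumPy-only sketch comparing the stacked and unstacked forms of the weight computation. It assumes made-up 1D weights and uses `np.kron` as an independent reference for the tensor-product weights. It does not cover torch or the other backends, which would need their own check.

```python
import numpy as np

def product_weights(weights, dim):
    # Proposed pattern: stack the meshgrid output before reducing over axis 0.
    return np.prod(np.stack(np.meshgrid(*([weights] * dim))), axis=0).ravel()

def product_weights_unstacked(weights, dim):
    # Pattern without the stack; np.prod converts the sequence to an array itself.
    return np.prod(np.meshgrid(*([weights] * dim)), axis=0).ravel()

w = np.array([0.2, 0.5, 0.3])  # hypothetical 1D weights
for dim in (1, 2, 3):
    stacked = product_weights(w, dim)
    reference = w
    for _ in range(dim - 1):
        reference = np.kron(reference, w)  # explicit tensor product as a reference
    assert np.allclose(stacked, product_weights_unstacked(w, dim))
    assert np.allclose(stacked, reference)
print("stacked and unstacked weights agree")
```

With NumPy the `anp.stack` call is effectively a no-op, since `np.prod` already converts the sequence to the same array.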
How Can It Be Tested or Reproduced

Using jax 0.4.35 and torchquad 0.4.0, run:

```python
import jax
import jax.numpy as jnp
from torchquad import set_up_backend, MonteCarlo, Gaussian, GaussLegendre

set_up_backend(backend="jax")

@jax.jit
def some_function(x):
    return jnp.power(x[:, 0] - x[:, 1], 2)

g = Gaussian()
# It also fails with GaussLegendre
# g = GaussLegendre()
integral_value = g.integrate(
    lambda x: some_function(x),
    dim=2,
    N=10000,
    integration_domain=jnp.asarray([[-1.0, 1.0], [-1.0, 1.0]]),
)
```
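To check a fixed integrator, the exact value of this integral is easy to derive. On [-1, 1]², (x − y)² = x² − 2xy + y². The cross term integrates to zero by symmetry, and each squared term contributes 2 · 2/3 = 4/3, so the integral is 8/3 ≈ 2.6667. The sketch below is an independent NumPy tensor-product Gauss-Legendre rule (not torchquad's implementation), built with the same meshgrid-then-stack pattern, that can serve as a reference value.

```python
import numpy as np

def gauss_legendre_2d(f, n):
    """Tensor-product Gauss-Legendre rule on [-1, 1]^2 (illustrative sketch, not torchquad's code)."""
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # Points: every (x, y) pair of 1D nodes, flattened to shape (n * n, 2).
    points = np.stack(np.meshgrid(nodes, nodes), axis=-1).reshape(-1, 2)
    # Weights: products of the matching 1D weights, in the same flattened order.
    point_weights = np.prod(np.stack(np.meshgrid(weights, weights)), axis=0).ravel()
    return np.sum(point_weights * f(points))

def some_function(x):
    # Same integrand as the reproduction above, written with NumPy.
    return (x[:, 0] - x[:, 1]) ** 2

print(gauss_legendre_2d(some_function, n=100), 8 / 3)
```

Because an n-point Gauss-Legendre rule is exact for polynomials up to degree 2n − 1 in each variable, even n = 2 already gives 8/3 up to rounding for this integrand.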