f0uriest / quadax

Numerical quadrature with JAX
MIT License

Error in Readme example #2

Closed: twhentschel closed this issue 10 months ago

twhentschel commented 11 months ago

Hi @f0uriest,

Thanks for creating this neat package! The assert statement in the README example fails for me. The fix was to enable 64-bit precision in JAX with these lines:

from jax import config
config.update("jax_enable_x64", True)
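A quick way to confirm the flag took effect is sketched below. It uses the equivalent `jax.config.update` spelling and assumes the update runs at startup, before any JAX arrays are created (arrays made earlier keep their 32-bit dtype).

```python
import jax

# Enable 64-bit floats; do this before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Arrays created after the update default to float64.
x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64
```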

An alternative would be to remove the first assert statement and note that JAX uses 32-bit floats by default. The second check works fine, but requiring agreement to 1e-14 might be too strict for other examples.
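To see why a 1e-14 check cannot pass in single precision, compare machine epsilon for the two float types. This is a small illustration using `jnp.finfo`, not code from the README.

```python
import jax.numpy as jnp

# Machine epsilon: the gap between 1.0 and the next representable float.
eps32 = float(jnp.finfo(jnp.float32).eps)  # 2**-23, about 1.19e-07
eps64 = float(jnp.finfo(jnp.float64).eps)  # 2**-52, about 2.22e-16
print(f"float32 eps = {eps32:.3e}, float64 eps = {eps64:.3e}")

# In float32, adding 1e-14 to 1.0 is lost to rounding, so any check
# demanding agreement to 1e-14 is beyond what the format can represent.
lost = bool(jnp.float32(1.0) + jnp.float32(1e-14) == jnp.float32(1.0))
print(lost)  # True

# JAX's default float dtype: float32 unless jax_enable_x64 is set.
print(jnp.ones(3).dtype)
```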

If you think this is a bug, I'd be happy to contribute a PR for either of these fixes.

f0uriest commented 10 months ago

Thanks for the catch!

I pretty much always use 64-bit, but I know a lot of people don't, so I've reduced the tolerances in the example and also made the default epsilon values vary with the dtype. It should now work well with both 32- and 64-bit precision.
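The idea of dtype-dependent default tolerances can be sketched as follows. This is not quadax's implementation: `default_tolerances`, `integrate_sin_0_to_pi`, and `passes_with_default_tol` are hypothetical helpers, the sqrt(machine epsilon) rule is an assumed heuristic, and NumPy's Gauss-Legendre nodes stand in for a quadax routine.

```python
import math

import numpy as np
import jax.numpy as jnp


def default_tolerances(dtype):
    """Hypothetical helper (not quadax's API): derive epsabs/epsrel from a dtype.

    Uses sqrt(machine epsilon), an assumed heuristic that keeps roughly half
    the available digits: ~3.5e-4 for float32 and ~1.5e-8 for float64.
    """
    eps = float(jnp.finfo(dtype).eps)
    tol = math.sqrt(eps)
    return tol, tol


def integrate_sin_0_to_pi(dtype, n=20):
    """Hypothetical example: Gauss-Legendre estimate of the integral of sin on [0, pi] (exact value 2)."""
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # If x64 is disabled, JAX warns and falls back to float32 for float64 requests.
    x = jnp.asarray(nodes, dtype=dtype)
    w = jnp.asarray(weights, dtype=dtype)
    half = math.pi / 2  # affine map from [-1, 1] onto [0, pi]
    return half * jnp.sum(w * jnp.sin(half * (x + 1)))


def passes_with_default_tol(dtype):
    """Check the estimate against tolerances chosen from the result's actual dtype."""
    approx = integrate_sin_0_to_pi(dtype)
    # Use the dtype the result really has, not the one requested.
    epsabs, epsrel = default_tolerances(approx.dtype)
    err = abs(float(approx) - 2.0)
    return err <= max(epsabs, epsrel * 2.0)


for dtype in (jnp.float32, jnp.float64):
    print(default_tolerances(dtype), passes_with_default_tol(dtype))
```

Deriving the tolerance from the result's dtype, rather than hard-coding a value like 1e-14, lets the same check pass in both 32-bit and 64-bit mode.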