pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Fix JAX warnings in tests #307

Closed jessegrabowski closed 4 months ago

jessegrabowski commented 4 months ago

Closes #305

The JAX warning mentioned in https://github.com/pymc-devs/pymc-experimental/issues/305 seems to be caused by using multiple cores when sampling with pm.sample in tests. The JAX warning appears to be over-eager: it is issued even when JAX is not being used, as long as JAX has been imported.

This PR adds the warning to filterwarnings in pyproject.toml, since it doesn't appear to be relevant to the tests it's causing to fail.
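For readers unfamiliar with message-based filtering, here is a minimal sketch of the idea using only the standard library warnings module. The message text in FORK_MESSAGE is an assumption (the thread does not quote the actual warning), and count_unfiltered is a hypothetical helper for illustration; the real change lives in the pytest filterwarnings setting, not in Python code like this.

```python
import warnings

# Assumed stand-in for the JAX warning text; the thread does not quote it.
FORK_MESSAGE = "os.fork() was called. os.fork() is incompatible with multithreaded code"


def count_unfiltered(message_regex=None):
    """Emit two warnings and return the messages that survive the filter."""
    with warnings.catch_warnings(record=True) as caught:
        # Record everything by default so nothing is hidden by earlier filters.
        warnings.simplefilter("always")
        if message_regex is not None:
            # The message argument is a regex matched against the start of
            # the warning text; later filters take precedence over earlier ones.
            warnings.filterwarnings(
                "ignore", message=message_regex, category=RuntimeWarning
            )
        warnings.warn(FORK_MESSAGE, RuntimeWarning)
        warnings.warn("something unrelated", RuntimeWarning)
    return [str(w.message) for w in caught]


if __name__ == "__main__":
    print(count_unfiltered())
    print(count_unfiltered(r"os\.fork\(\) was called"))
```

Filtering on the message keeps unrelated RuntimeWarnings visible, so only the one known-irrelevant warning is silenced.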

jessegrabowski commented 4 months ago

I wanted to filter by module (jax._src.xla_bridge), but it didn't work. This does, so I'm going with it.
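One possible reason a module filter fails, not confirmed in this thread: the module field of a warning filter is matched against the module the warning is attributed to, which can differ from the module whose code raises it (for example when the warning is issued with a stacklevel that points at the caller). The sketch below uses warnings.warn_explicit to set the attributed module directly; the module names and the is_suppressed helper are hypothetical.

```python
import warnings


def is_suppressed(filter_module, attributed_module):
    """Return True if a module-based ignore filter hides the warning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The module argument is a regex matched against the start of the
        # module name the warning is attributed to.
        warnings.filterwarnings(
            "ignore", category=RuntimeWarning, module=filter_module
        )
        # warn_explicit lets us choose the attributed module by hand, which
        # imitates what a stacklevel pointing at the caller would do.
        warnings.warn_explicit(
            "demo warning",
            RuntimeWarning,
            filename="demo.py",
            lineno=1,
            module=attributed_module,
        )
    return len(caught) == 0


if __name__ == "__main__":
    # Filter names the emitting module and the warning is attributed to it.
    print(is_suppressed(r"emitter_pkg\.internal", "emitter_pkg.internal"))
    # Same filter, but the warning is attributed to the calling module.
    print(is_suppressed(r"emitter_pkg\.internal", "caller_pkg.runner"))
```

If that is what happens here, a message-based filter is the more robust choice, because it does not depend on which frame the warning is attributed to.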

By far my worst work, but it fixes the issue.

I'd still like to move tests up one level, out of the project files.

ricardoV94 commented 4 months ago

I'd still like to move tests up one level, out of the project files.

Let's do that in a separate PR