pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Stress test utility for numpyro? #1833

Open SamuelBrand1 opened 1 month ago

SamuelBrand1 commented 1 month ago

Hi everyone,

I'm working on an ODE-based model (solver from diffrax) for respiratory viruses, fitted on past seasons of data.

At the moment, we occasionally hit maximum-iteration failures during NUTS warm-up. This is a bit of a classic problem with ODE models, but I was wondering if there is a stress test utility in numpyro that samples parameters (in the unconstrained domain, I guess) and records the parameters that cause a failure in the model's log posterior density call (e.g. like this from LogDensityProblems.jl).

I had a look in the docs but I might well have missed something; apologies in advance if I have.

fehiepsi commented 1 month ago

I think you can draw samples from the prior and use the log_density utility to check which of them cause NaN. We don't have a dedicated utility for that, but please feel free to submit a PR. The feature would be useful for many use cases.

SamuelBrand1 commented 1 month ago

Cheers for the update.