Open kylejcaron opened 2 years ago
Hmm can you try to change this line?

```python
c_ = pm.MutableData("c_", np.where(df.censored == 1, df.event_time, np.inf))
```

Otherwise those NaN in the graph will cause havoc. What happens when you use pm.sample, doesn't it also complain?
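The suggested np.where call builds the censoring bound used by pm.Censored: the observed event time where a row is censored, and np.inf (no bound) elsewhere. A minimal numpy-only sketch of that array, with toy values and column names taken from the thread:

```python
import numpy as np

# toy stand-ins for df.censored and df.event_time from the thread
censored = np.array([0, 1, 0, 1])
event_time = np.array([2.0, 5.0, 3.5, 7.0])

# upper bound fed to pm.MutableData / pm.Censored: the recorded time where
# censoring occurred, +inf otherwise, so no NaNs propagate through the graph
upper = np.where(censored == 1, event_time, np.inf)
# -> inf at uncensored rows, 5.0 and 7.0 at the censored rows
```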
just tested np.inf and it didn't solve the issue - and using pm.sample works fine with the original example and doesn't complain
What's really weird is that if I add a hierarchical global mean to one of the parameters it relieves some of the behavior (but the sampling still looks pretty off)
for instance if I change this line from my first example

```python
mu_log_k = pm.Normal("mu_log_k", 0.5, 0.1)
```

to

```python
mu_log_k = pm.Normal("mu_log_k", 0.5, 0.1)
log_k = pm.Normal("log_k", mu_log_k, 0.5, dims="group")
```
then the sampling is improved but still seems problematic, with the following user warning from arviz when plotting a trace (and it seems like there are still some chains that are impacted):

```
/Users/kcaron/.pyenv/versions/3.9.7/envs/pymc4_env/lib/python3.9/site-packages/arviz/stats/density_utils.py:491: UserWarning: Your data appears to have a single value or no finite values
  warnings.warn("Your data appears to have a single value or no finite values")
```
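That arviz warning fires when the values being plotted are effectively constant (or non-finite), which is consistent with chains that never move: a density plot needs positive spread to estimate a bandwidth. A quick numpy illustration of the zero-spread condition (toy values, not arviz internals):

```python
import numpy as np

stuck_draws = np.full(1000, 0.37)  # a chain whose draws never move
healthy_draws = np.random.default_rng(0).normal(size=1000)

# peak-to-peak range: zero for the stuck chain, positive for the healthy one
print(np.ptp(stuck_draws))    # 0.0 -> "single value" warning territory
print(np.ptp(healthy_draws))  # positive, so a density plot works normally
```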
It sounds like your model is very ill conditioned and Numpyro just does a slightly worse job than PyMC.
~~Are you sure that np.random.lognormal prior corresponds to the PyMC one?~~ Seems to be correct
Also how many divergences are you getting with pm.sample?
The model seems fine from what I know - there are no divergences, high effective sample sizes, and rhat is pretty much all rhat=1.0 (I found 3/400 parameters that had r_hat=1.01)
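For reference, rhat compares between-chain and within-chain variance, so values near 1.0 only certify mixing when the chains actually move. A compact sketch of the classic split-R-hat computation (a hypothetical helper, not arviz's rank-normalized implementation):

```python
import numpy as np

def split_rhat(samples):
    """Classic split-R-hat (Gelman-Rubin) for an array of shape (chains, draws)."""
    chains, draws = samples.shape
    half = draws // 2
    # split each chain in half -> 2*chains sequences of length `half`
    seqs = samples[:, : 2 * half].reshape(chains * 2, half)
    within = seqs.var(axis=1, ddof=1).mean()        # W: mean within-sequence variance
    between = half * seqs.mean(axis=1).var(ddof=1)  # B: scaled variance of sequence means
    var_est = (half - 1) / half * within + between / half
    return float(np.sqrt(var_est / within))

rng = np.random.default_rng(0)
well_mixed = rng.normal(size=(4, 1000))  # 4 chains exploring the same target
print(split_rhat(well_mixed))            # close to 1.0
```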
In that case my best guess is some numerical instabilities/errors in the jax backend
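One concrete thing worth checking here: JAX computes in 32-bit floats by default unless `jax.config.update("jax_enable_x64", True)` is set, and single precision can silently swallow small increments in log-density accumulation. A numpy-only illustration of the precision gap (offered as a possible cause, not a confirmed diagnosis):

```python
import numpy as np

# at 1e8, float32 spacing is 8.0, so adding 1.0 is lost entirely;
# float64 keeps the increment
x32 = np.float32(1e8) + np.float32(1.0)
x64 = np.float64(1e8) + np.float64(1.0)

print(x32 == np.float32(1e8))  # True: the +1 vanished in single precision
print(x64 == np.float64(1e8))  # False: double precision retains it
```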
Description of your problem
When using pm.sampling_jax with pm.Censored, the sampled posterior contains the same value for all draws within each chain. So for instance if I had 4 chains and 1000 draws per chain, I'd see every chain repeating a single value across all of its draws.
Please provide a minimal, self-contained, and reproducible example.
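The symptom described (4 chains x 1000 draws, one repeated value per chain) can be checked directly on a posterior array; a small numpy sketch using simulated values, not the actual model output:

```python
import numpy as np

# simulate the reported failure: shape (chains, draws), each chain frozen at one value
posterior = np.repeat(np.array([[0.1], [0.4], [-0.2], [0.9]]), 1000, axis=1)

# flag chains where every draw equals that chain's first draw
stuck = np.all(posterior == posterior[:, :1], axis=1)
print(stuck)  # all True here; a healthy posterior would be all False
```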
Please provide the full traceback.
Not applicable - there is no raised error in this case, but the functionality definitely is not working as intended as far as I can tell.
Complete error traceback

```python
[The complete error output here]
```

Please provide any additional information below.
Versions and main components