Sure, it makes sense to have more than one instance (or "chain"). Then we could apply diagnostics like R-hat and rank plots. Other diagnostics, like ESS, don't make sense.
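As a rough illustration of what R-hat across runs would measure, here is a minimal sketch of the classic Gelman-Rubin statistic. It assumes the final particles of each independent SMC run are stacked into an array of shape (n_runs, n_particles) and treated like chains. The function name gelman_rubin_rhat is hypothetical; in practice ArviZ's az.rhat (rank-normalized split R-hat) would be the usual choice.

```python
import numpy as np


def gelman_rubin_rhat(draws):
    """Classic (non-split, non-rank-normalized) Gelman-Rubin R-hat.

    draws: array of shape (n_runs, n_particles), one row per independent run.
    Hypothetical helper for illustration only.
    """
    draws = np.asarray(draws, dtype=float)
    _, n = draws.shape
    chain_means = draws.mean(axis=1)
    # W: average within-run variance
    within = draws.var(axis=1, ddof=1).mean()
    # B: variance of the run means, scaled by the run length
    between = n * chain_means.var(ddof=1)
    # Pooled estimate of the posterior variance
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
agreeing = rng.normal(size=(4, 1000))           # four runs targeting the same distribution
disagreeing = agreeing + np.arange(4)[:, None]  # runs centred at 0, 1, 2, 3
print(gelman_rubin_rhat(agreeing), gelman_rubin_rhat(disagreeing))
```

Runs that agree give a value close to 1, while runs stuck in different places push it well above 1.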
Hi there,
I was experimenting with JAX sampling. Instead of running successive calls to sample_smc_blackjax (which I thought would be useful for comparing the posteriors and enabling R-hat-like diagnostics), I was wondering whether it would make sense to hard-code a given number of successive jobs in the code itself (pymc_experimental/inference/smc/sampling.py). Does it make sense to build the kernel only once and then run inference_loop + arviz_from_particles several times?
Cheers, Vian
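The "build the kernel once, run it several times" idea can be sketched as follows. This is a toy under loud assumptions: make_kernel and its importance-sampling-with-resampling step are hypothetical stand-ins, not sample_smc_blackjax or the blackjax SMC kernel. Each seed produces one independent run, and the runs are stacked along a leading chain axis, which is the (chain, draw) layout ArviZ expects.

```python
import numpy as np


def make_kernel(mu_true=1.0, sigma=1.0, proposal_sd=3.0):
    """Hypothetical stand-in for building the SMC kernel once.

    Returns a function that runs one toy "inference loop": importance
    sampling from N(0, proposal_sd^2) toward N(mu_true, sigma^2), followed by
    multinomial resampling. This is NOT the blackjax SMC kernel.
    """

    def run(seed, n_particles=2000):
        rng = np.random.default_rng(seed)
        x = rng.normal(0.0, proposal_sd, size=n_particles)
        # log target - log proposal, with additive constants dropped
        logw = -0.5 * ((x - mu_true) / sigma) ** 2 + 0.5 * (x / proposal_sd) ** 2
        w = np.exp(logw - logw.max())
        w /= w.sum()
        # Resample to get equally weighted particles
        idx = rng.choice(n_particles, size=n_particles, p=w)
        return x[idx]

    return run


run = make_kernel()  # built once
n_runs = 4
# Reuse the same kernel with a different seed per run; stack as (chain, draw)
posterior = np.stack([run(seed) for seed in range(n_runs)])
print(posterior.shape, posterior.mean(axis=1))
```

In the actual JAX code, the analogous step would presumably be reusing the same (jitted) inference loop with different PRNG keys, for example obtained via jax.random.split, but how that maps onto sampling.py is an assumption, not something confirmed here.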