Open ricardoV94 opened 1 year ago
I think we could do this by using a somewhat lower-level interface in nuts-rs. Right now we just call `sample_parallel`, which takes care of multithreading the different chains and gives us draws from any chain as they become available.
But we can also instantiate the individual chains directly and call those from a PyMC step method.
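A rough sketch of what that could look like from the Python side. Everything here is hypothetical: `Chain`, `ToyChain`, and `ChainStep` are made-up names, the toy chain is a random-walk Metropolis stand-in rather than NUTS, and none of this is the actual nuts-rs, nutpie, or PyMC step-method API. It only illustrates the control flow, where a step method owns one chain and advances it by a single draw each time it is called.

```python
import math
import random
from typing import Dict, Iterable, Protocol


class Chain(Protocol):
    """Hypothetical single-chain interface (an assumption, not the nuts-rs API)."""

    def set_position(self, position: Dict[str, float]) -> None: ...
    def draw(self) -> Dict[str, float]: ...


class ToyChain:
    """Stand-in chain: random-walk Metropolis targeting a standard normal."""

    def __init__(self, seed: int = 0, scale: float = 1.0) -> None:
        self._rng = random.Random(seed)
        self._scale = scale
        self._position = {"x": 0.0}

    def set_position(self, position: Dict[str, float]) -> None:
        self._position = dict(position)

    def draw(self) -> Dict[str, float]:
        x = self._position["x"]
        proposal = x + self._rng.gauss(0.0, self._scale)
        # Log acceptance ratio for a standard normal target.
        log_ratio = 0.5 * (x * x - proposal * proposal)
        if self._rng.random() < math.exp(min(0.0, log_ratio)):
            self._position = {"x": proposal}
        return dict(self._position)


class ChainStep:
    """Hypothetical step-method wrapper: one call advances the chain by one draw."""

    def __init__(self, chain: Chain, var_names: Iterable[str]) -> None:
        self.chain = chain
        self.var_names = list(var_names)

    def step(self, point: Dict[str, float]) -> Dict[str, float]:
        # Other step methods may have changed the point since the last call,
        # so push the variables this step owns back into the chain first.
        self.chain.set_position({name: point[name] for name in self.var_names})
        new_point = dict(point)
        new_point.update(self.chain.draw())
        return new_point
```

In a compound sampler, `step` would be called in turn with the other step methods, each updating only its own variables. The real interop cost (crossing the Rust/Python boundary on every draw) is exactly what this sketch does not capture.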
I don't know much about Rust/Python interop, but I assume that refactoring the trace backend is the right first step regardless. I'll try to make another step in that direction in my next PR.
It would be neat if we could use nutpie as a step sampler alongside the other PyMC samplers (without killing performance too much).
This might justify refactoring the trace backend. CC @michaelosthege
@aseyboldt any hints on how one could start investigating this?