fritzo opened this issue 4 years ago.
This is nice to have! If you need any functionality from jax/numpyro, just let me know. Regarding `block_trace_but_allow_replay_and_condition`, it is probably simpler in NumPyro, because you can use a `control_flow` primitive (as in `numpyro.contrib.control_flow.scan`) and check whether `substitute_stack` is empty. If that stack provides all the information needed for vectorization, you can run `vectorized_scan`.
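To make the dispatch idea concrete, here is a minimal pure-Python sketch. It does not use any NumPyro API: the functions `scan`, `sequential_scan`, `vectorized_scan` and `normal_logpdf` are hypothetical (only the name `vectorized_scan` comes from the comment above), a plain `substituted` dict stands in for a non-empty substitute stack, and the model is assumed to be a Gaussian random walk x_t ~ Normal(x_{t-1}, 1) chosen purely for illustration.

```python
import math
import random
from typing import Dict, List, Tuple


def normal_logpdf(x: float, loc: float, scale: float = 1.0) -> float:
    # Log density of Normal(loc, scale**2) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def sequential_scan(x0: float, length: int, rng: random.Random) -> Tuple[List[float], float]:
    # Standard interpretation: x_t is *sampled* from Normal(x_{t-1}, 1),
    # so step t cannot begin until step t-1 has produced its value.
    xs: List[float] = []
    log_prob, prev = 0.0, x0
    for _ in range(length):
        x = rng.gauss(prev, 1.0)
        log_prob += normal_logpdf(x, prev)
        xs.append(x)
        prev = x
    return xs, log_prob


def vectorized_scan(x0: float, xs: List[float]) -> Tuple[List[float], float]:
    # Every x_t is known up front, so each transition term depends only on
    # given values; the terms are independent and could run in parallel.
    prevs = [x0] + xs[:-1]
    terms = [normal_logpdf(x, p) for x, p in zip(xs, prevs)]
    return list(xs), sum(terms)


def scan(x0: float, length: int, substituted: Dict[str, float],
         rng: random.Random) -> Tuple[str, Tuple[List[float], float]]:
    # Hypothetical dispatch: `substituted` stands in for the values that a
    # non-empty substitute stack would supply for sites "x_0", "x_1", ...
    names = [f"x_{t}" for t in range(length)]
    if substituted and all(name in substituted for name in names):
        return "vectorized", vectorized_scan(x0, [substituted[n] for n in names])
    return "sequential", sequential_scan(x0, length, rng)
```

With an empty `substituted` dict the sketch samples step by step; passing back every sampled value takes the vectorized path and reproduces the same log-density. A partially filled dict falls back to sequential sampling, since the stack would not provide all the information needed.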
@fritzo Can I take a stab at this? :)
@fehiepsi sure, and let me know if you want me to review any code.
This proposes implementing a vectorized interpretation of `pyro.scan` that fully parallelizes over the time axis. It follows the hand-written vectorization in `pyro.contrib.epidemiology` (in `CompartmentalModel._relaxed_model()` and `._vectorized_model()`). This interpretation works only under replay or condition, because the standard interpretation of sampling cannot be parallelized over time; it will therefore require a three-way interaction between `pyro.scan`, `poutine.condition`, and the new interpretation.

Here is a vague sketch of the parallel implementation of `pyro.scan`:

cc @fehiepsi @eb8680
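As a minimal illustration of why conditioning makes time-parallelism possible (this is not the sketch referenced in the proposal, nor the code in `pyro.contrib.epidemiology`), the NumPy snippet below assumes an AR(1) transition x_t ~ Normal(rho * x_{t-1}, sigma) with every state observed. The names `transition_log_prob`, `sequential_log_density`, `parallel_log_density`, `rho` and `sigma` are hypothetical choices for this example.

```python
import numpy as np


def transition_log_prob(prev, curr, rho=0.9, sigma=0.5):
    # log Normal(curr | rho * prev, sigma**2); elementwise, so it accepts
    # scalars or whole arrays of (prev, curr) pairs.
    z = (curr - rho * prev) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def sequential_log_density(init, xs):
    # Mirrors a step-by-step scan: one transition per loop iteration.
    total, prev = 0.0, init
    for x in xs:
        total += transition_log_prob(prev, x)
        prev = x
    return total


def parallel_log_density(init, xs):
    # When all x_t are conditioned, shifting the series by one step yields
    # every (x_{t-1}, x_t) pair at once, so the time axis becomes an
    # ordinary batch dimension and no loop is needed.
    xs = np.asarray(xs, dtype=float)
    prevs = np.concatenate([[init], xs[:-1]])
    return transition_log_prob(prevs, xs).sum()
```

For example, with `xs = np.random.default_rng(0).normal(size=1000)`, both functions return the same value up to floating-point rounding. The shift trick only works because the values are given; under the standard sampling interpretation, `prevs` would not exist until each step had run.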