pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Vectorized interpretation for pyro.scan #686

Open fritzo opened 4 years ago

fritzo commented 4 years ago

This proposes to implement a vectorized interpretation of pyro.scan that completely parallelizes over the time axis. This follows a hand-implementation of vectorization in pyro.contrib.epidemiology (in CompartmentalModel._relaxed_model() and ._vectorized_model()). This interpretation works only under replay or condition, since the standard interpretation of sampling cannot be time-parallelized: each step's sample depends on the value drawn at the previous step. Therefore this will require a three-way interaction between pyro.scan, poutine.condition, and the new interpretation.
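To illustrate why conditioning is what unlocks time-parallelism, here is a toy NumPy sketch (not Pyro or NumPyro code; all function names are hypothetical) for a Gaussian random walk x_t ~ Normal(x_{t-1}, scale). Drawing the walk needs a loop, because x_t depends on the freshly drawn x_{t-1}. Once every x_t is fixed, however, each step's log density depends only on known values. A single vectorized call over the shifted series then gives the same total as a step-by-step scan.

```python
import math

import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log density of Normal(loc, scale).
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def sample_walk(rng, init, scale, num_steps):
    # Sampling is inherently sequential: x_t cannot be drawn before x_{t-1}.
    xs = []
    prev = init
    for _ in range(num_steps):
        prev = rng.normal(prev, scale)
        xs.append(prev)
    return np.array(xs)


def sequential_log_prob(xs, init, scale):
    # Score step by step, carrying the previous state like a scan does.
    total = 0.0
    prev = init
    for x in xs:
        total += normal_logpdf(x, prev, scale)
        prev = x
    return total


def vectorized_log_prob(xs, init, scale):
    # With every x_t known (conditioned), score all steps at once against
    # the shifted series prev = [init, x_0, ..., x_{T-2}].
    prev = np.concatenate([[init], xs[:-1]])
    return np.sum(normal_logpdf(xs, prev, scale))


rng = np.random.default_rng(0)
xs = sample_walk(rng, init=0.0, scale=0.5, num_steps=100)
print(sequential_log_prob(xs, 0.0, 0.5), vectorized_log_prob(xs, 0.0, 0.5))
```

The two printed numbers agree up to floating-point rounding; only the sampling step is forced to stay sequential.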

Here is a vague sketch of the parallel implementation of pyro.scan:

```python
import torch
import pyro
import pyro.poutine as poutine


def vectorized_scan(transition, time, init):
    """
    This assumes ``init`` is a dict mapping unindexed sample site name
    (i.e. "x" rather than "x_0") to value. TODO generalize to PyTree.

    This assumes ``time`` is a range object. TODO generalize to jnp.arange?
    """
    # Trace the first step and assume model structure is fixed.
    # In pyro.contrib.epidemiology we do this once at the start of inference.
    # Maybe we could memoize to avoid duplicated execution?
    with poutine.block(), poutine.trace() as tr:
        t = 0
        transition(init, t)
    # Strip the "_0" suffix so names are unindexed, e.g. "x" rather than "x_0".
    names = [name[: -len("_0")] for name in tr.trace.stochastic_nodes
             if name.endswith("_0")]

    # The remainder is vectorized over time.
    with pyro.plate("time", len(time)):  # or maybe jax.vmap
        t = slice(0, len(time), 1)  # or maybe jnp.arange

        # Record vectorized values.
        curr = {}
        prev = {}
        # Proposed handler (does not exist yet).
        with poutine.block_trace_but_allow_replay_and_condition():
            for name in names:
                name_0 = "{}_{}".format(name, 0)
                name_t = "{}_{}".format(name, t)
                site_0 = tr.trace.nodes[name_0]
                curr[name] = pyro.sample(name_t, site_0["fn"])
                prev[name] = torch.cat([site_0["value"].unsqueeze(0), curr[name]])

        # Execute vectorized transition.
        transition(prev, t)
    return ...
```
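The sketch's two phases (trace a single step to discover site names, then run the transition once on time-stacked values) can be mimicked in plain NumPy. This is a toy under stated assumptions, not Pyro code. `Conditioned` is a hypothetical stand-in for poutine.trace plus poutine.condition. The toy also assumes the previous state at step t is `init` for t = 0 and x_{t-1} otherwise. Following the sketch, the vectorized site is indexed by a slice.

```python
import math

import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log density of Normal(loc, scale).
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


class Conditioned:
    """Hypothetical stand-in for poutine.trace + poutine.condition:
    records each site name and accumulates the log density of the
    observed value looked up in ``data``."""

    def __init__(self, data):
        self.data = data
        self.names = []
        self.log_prob = 0.0

    def sample(self, name, loc, scale):
        value = self.data[name]
        self.names.append(name)
        self.log_prob += np.sum(normal_logpdf(value, loc, scale))
        return value


def transition(ctx, state, t):
    # One step of a Gaussian random walk; sites are named "x_0", "x_1", ...
    x = ctx.sample("x_{}".format(t), state["x"], 0.5)
    return {"x": x}


def sequential_scan(transition, time, init, data):
    # Reference: run the transition one step at a time.
    ctx = Conditioned(data)
    state = init
    for t in time:
        state = transition(ctx, state, t)
    return ctx.log_prob


def vectorized_scan(transition, time, init, data):
    # 1. Trace step 0 only, to discover unindexed site names ("x_0" -> "x").
    probe = Conditioned(data)
    transition(probe, init, 0)
    names = [name[: -len("_0")] for name in probe.names if name.endswith("_0")]

    # 2. Stack observed values over time and shift by one step, so that
    #    prev[n] = [init[n], x_0, ..., x_{T-2}] lines up with curr[n].
    curr = {n: np.array([data["{}_{}".format(n, s)] for s in time]) for n in names}
    prev = {n: np.concatenate([[init[n]], curr[n][:-1]]) for n in names}

    # 3. Run the transition once on whole arrays. As in the sketch, t is a
    #    slice, so the vectorized site is named e.g. "x_slice(0, 5, 1)".
    t = slice(0, len(time), 1)
    ctx = Conditioned({"{}_{}".format(n, t): curr[n] for n in names})
    transition(ctx, prev, t)
    return ctx.log_prob


data = {"x_{}".format(t): v for t, v in enumerate([0.3, -0.1, 0.4, 1.2, 0.9])}
init = {"x": 0.0}
print(sequential_scan(transition, range(5), init, data))
print(vectorized_scan(transition, range(5), init, data))
```

Both calls score the same conditioned trajectory, but the vectorized one invokes `transition` only twice (once to probe, once on arrays) regardless of the number of steps.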

cc @fehiepsi @eb8680

fehiepsi commented 4 years ago

This is nice to have! If you need any functionality from jax/numpyro, just let me know. About block_trace_but_allow_replay_and_condition: it is probably simpler in NumPyro, because you can use a control_flow primitive (as in numpyro.contrib.control_flow.scan) and check whether substitute_stack is empty. If that stack provides all the information needed for vectorization, then you can run vectorized_scan.
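The dispatch idea can be sketched as a toy that does not touch NumPyro internals. A plain `data` dict plays the role of the substitute stack. `can_vectorize` and `random_walk_scan` are hypothetical names. The scan takes the vectorized path only when every site is supplied at every step, and otherwise falls back to sequential execution.

```python
import numpy as np


def can_vectorize(names, data, num_steps):
    # Hypothetical check: vectorizing is only valid if the substituted data
    # supplies every site at every step ("x_0", ..., "x_{T-1}"), so no step
    # has to wait for a value sampled at the previous step.
    return all("{}_{}".format(n, t) in data for n in names for t in range(num_steps))


def random_walk_scan(init, scale, num_steps, data, rng):
    """Toy scan for x_t ~ Normal(x_{t-1}, scale) that chooses its
    interpretation from the available data. Returns (mode, xs)."""
    if can_vectorize(["x"], data, num_steps):
        # Vectorized path: gather all substituted values into one array.
        xs = np.array([data["x_{}".format(t)] for t in range(num_steps)], dtype=float)
        return "vectorized", xs
    # Sequential path: walk forward, using data where present and sampling
    # (which needs the previous value) where it is missing.
    xs = []
    prev = init
    for t in range(num_steps):
        name = "x_{}".format(t)
        prev = data[name] if name in data else rng.normal(prev, scale)
        xs.append(prev)
    return "sequential", np.array(xs, dtype=float)


rng = np.random.default_rng(0)
full = {"x_{}".format(t): 0.1 * t for t in range(4)}
partial = {"x_0": 0.0, "x_2": 0.5}
print(random_walk_scan(0.0, 1.0, 4, full, rng)[0])     # vectorized
print(random_walk_scan(0.0, 1.0, 4, partial, rng)[0])  # sequential
```

In a real implementation the check would inspect the handler stack rather than a dict, and a partially observed series might still admit partial vectorization; the toy keeps the all-or-nothing rule for simplicity.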

fehiepsi commented 3 years ago

@fritzo Can I take a stab at this? :)

fritzo commented 3 years ago

@fehiepsi sure, and let me know if you want me to review any code.