pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai

Vectorized random_module() #503

Open hawkrobe opened 6 years ago

hawkrobe commented 6 years ago

Distributions in Pyro can generally take a tensor of parameters and return a corresponding tensor of samples, which allows models to be vectorized nicely over batches.
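For example, a single sample statement with batched parameters draws a whole tensor of samples at once (a minimal sketch; the site name and shapes are arbitrary):

```python
import torch
import pyro
import pyro.distributions as dist

# A (5,)-shaped parameter tensor yields a (5,)-shaped tensor of samples
# from one vectorized sample statement.
locs = torch.zeros(5)
scales = torch.ones(5)
x = pyro.sample("x", dist.Normal(locs, scales))  # x.shape == (5,)
```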

For many applications of random_module() (e.g. hierarchical models), however, there's currently no corresponding way to vectorize: you can't sample a single tensor of k nns (e.g. one for each domain) and index into them with a batch tensor.

@jpchen and I were chatting about a way of handling this: extend the sampling method of random_module() with a sample_size option that internally creates one giant nn whose parameters carry an extra leading dimension indexing which of the k sampled nns you want. This would let you run a single tensor of data through all the corresponding nns simultaneously. From our conversation:

i think a solution might be ... to parallelize at the layer level, i.e. have one giant nn that operates on matrices individually. this is possible because we have a multivariate normal, so we can sample all domain parameters together, and also because the inputs and outputs are all the same size
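Concretely, that batched layer might look like the following in plain PyTorch (a rough sketch: the BatchLinear name and shapes are illustrative, not an existing Pyro API):

```python
import torch
import torch.nn as nn

class BatchLinear(nn.Module):
    """k independent linear layers stored as one batched parameter tensor."""
    def __init__(self, k, in_features, out_features):
        super().__init__()
        # The leading dimension indexes which of the k sampled nns to use.
        self.weight = nn.Parameter(torch.randn(k, in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(k, 1, out_features))

    def forward(self, x):
        # x: (k, batch, in_features) -> (k, batch, out_features),
        # one matrix multiply per nn via batch matrix multiply.
        return torch.bmm(x, self.weight) + self.bias

# Run a tensor of data through all k nns simultaneously.
layer = BatchLinear(k=4, in_features=3, out_features=2)
x = torch.randn(4, 8, 3)   # 8 data points for each of the 4 domains
y = layer(x)               # y.shape == (4, 8, 2)
```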

jpchen commented 6 years ago

cc: @karalets, he's doing this as well, and from our discussions he's doing it with batch matrix multiply

eb8680 commented 6 years ago

Sounds like a good idea, but I suspect it might run into issues with broadcasting in nn.Modules. You could try expanding parameters manually for a few different networks to see if they behave correctly.
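For example, a quick check along these lines (hypothetical test code, not part of Pyro) could stack the parameters of a few nn.Linear layers into batched tensors and compare one bmm against a Python loop:

```python
import torch
import torch.nn as nn

k, n_in, n_out = 3, 4, 2
nets = [nn.Linear(n_in, n_out) for _ in range(k)]
x = torch.randn(k, 5, n_in)  # 5 inputs for each of the k networks

# Run the k networks one at a time.
looped = torch.stack([nets[i](x[i]) for i in range(k)])

# Expand the parameters into batched tensors and do a single bmm instead.
W = torch.stack([net.weight.t() for net in nets])          # (k, n_in, n_out)
b = torch.stack([net.bias for net in nets]).unsqueeze(1)   # (k, 1, n_out)
batched = torch.bmm(x, W) + b

print(torch.allclose(looped, batched, atol=1e-6))  # True if broadcasting behaves
```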

hawkrobe commented 6 years ago

I still think this is an important feature for making random_module() usable at scale, and from recent conversations with @jpchen, it sounds like a minimal version allowing batching is in progress. I understand wanting to clean up the issues, and obviously enhancements like this aren't top priority, but could this be reopened until there's consensus about whether it'll be supported or not? (Or is there a temporary hack outside Pyro that achieves the same efficiency?)

jpchen commented 6 years ago

We can make this issue more concrete to track progress. Based on various discussions, there seem to be two v0 ideas worth trying that support the same use case:

1) Have random_module lift a "batch nn" and run a subset of it forward. The lifted nn would need to be a batch nn (all of its linear operations use torch.bmm), and its forward call would take an argument specifying which nns to run.

2) Take batch priors (priors for all the domains/nns) and an index tensor upon calling, which creates a nn with parameters sampled from the indexed priors. So on each call, a new nn is constructed and sampled.

(on phone right now, will add code examples later)

Idea (1) has the restriction that the user must know the nn structure, i.e. the shapes of the inputs, a priori. I have a basic implementation of (2) on a branch that I can open for review after some cleanup; a rough sketch of the idea is below.
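In the meantime, here's roughly what idea (2) could look like, assuming Normal priors over all parameters (the indexed_random_module helper below is hypothetical, not the branch's actual API):

```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

def indexed_random_module(name, net, prior_locs, prior_scales, idx):
    """Hypothetical sketch of idea (2).

    prior_locs / prior_scales: dicts mapping each parameter name of `net`
    to a (k, *param_shape) tensor holding priors for all k domains/nns.
    idx: which domain's priors to use on this call.
    """
    prior = {
        pname: dist.Normal(prior_locs[pname][idx], prior_scales[pname][idx])
        for pname, _ in net.named_parameters()
    }
    lifted = pyro.random_module(name, net, prior)
    return lifted()  # a fresh nn with parameters sampled from the indexed priors

# Usage: k sets of priors, one nn sampled for domain 4.
k = 10
net = nn.Linear(3, 2)
locs = {p: torch.zeros(k, *v.shape) for p, v in net.named_parameters()}
scales = {p: torch.ones(k, *v.shape) for p, v in net.named_parameters()}
sampled_net = indexed_random_module("domain_net", net, locs, scales, idx=4)
```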