hawkrobe opened 6 years ago
cc: @karalets, he's doing this as well and, from our discussions, doing it with batch matrix multiply
Sounds like a good idea, but I suspect it might run into issues with broadcasting in `nn.Module`s. You could try expanding parameters manually for a few different networks to see if they behave correctly.
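As a rough sketch of that check (all shapes and names here are illustrative, not anything in Pyro), you can manually expand a linear layer's parameters to a batch of `B` networks, run the batched forward with `torch.bmm`, and compare against looping over each network with `F.linear`:

```python
import torch
import torch.nn.functional as F

# Hypothetical sanity check: do manually expanded parameters behave like
# B separate linear layers?
B, in_dim, out_dim, N = 4, 3, 5, 7

# one weight/bias per network in the batch
W = torch.randn(B, out_dim, in_dim)
b = torch.randn(B, out_dim)
x = torch.randn(B, N, in_dim)  # N data points per network

# batched forward: (B, N, in) @ (B, in, out) -> (B, N, out)
y_batched = torch.bmm(x, W.transpose(1, 2)) + b.unsqueeze(1)

# reference: run each network separately
y_loop = torch.stack([F.linear(x[i], W[i], b[i]) for i in range(B)])
assert torch.allclose(y_batched, y_loop, atol=1e-6)
```

If the two disagree, the broadcasting inside the batched module is the likely culprit.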
I still think this is an important feature to make `random_module()` usable at scale, and from recent conversations with @jpchen, it sounds like a minimal version allowing batching is in progress. I can understand wanting to clean up the issues, and obviously enhancements like this aren't top priority, but could this be reopened until there's consensus about whether it'll be supported or not? (Or is there a temporary hack outside Pyro to achieve the same efficiency?)
We can make this issue more concrete to track progress. Based on various discussions, there seem to be two v0 ideas worth trying that support the same use case:
1) Give `random_module` the ability to lift a batch nn and run a subset forward. The lifted nn would need to be a "batch nn" (all its linear operations need to use `torch.bmm`), and the forward call would take an argument specifying which nns to run forward.
2) Take batch priors (priors for all the domains/nns) and an index tensor upon calling, which creates a nn with parameters sampled from the indexed priors. So on each call, a new nn is constructed and sampled.
(on phone right now, will add code examples later)
Idea (1) has the restriction that the user knows the nn structure, i.e. the shape of the inputs, a priori. I have a basic implementation of (2) on a branch that I can open for review after some cleanup.
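A minimal sketch of what idea (2) might look like, written with plain `torch.distributions` rather than Pyro's actual API (the `sample_and_run` helper and all shapes are hypothetical): keep "batch priors" over `K` domains, and on each call sample parameters only for the indexed domains and build a fresh batched layer from them.

```python
import torch
from torch.distributions import Normal

# Hypothetical: priors over the weights of one linear layer, one per domain.
K, in_dim, out_dim = 10, 3, 2
prior_loc = torch.zeros(K, out_dim, in_dim)
prior_scale = torch.ones(K, out_dim, in_dim)

def sample_and_run(x, idx):
    # x: (len(idx), N, in_dim); idx: which domains' nns to run forward.
    loc, scale = prior_loc[idx], prior_scale[idx]
    W = Normal(loc, scale).sample()          # (len(idx), out_dim, in_dim)
    return torch.bmm(x, W.transpose(1, 2))   # (len(idx), N, out_dim)

idx = torch.tensor([0, 3, 7])
out = sample_and_run(torch.randn(3, 5, in_dim), idx)
```

Note that, as in the proposal, each call draws fresh parameters from the indexed priors, so the "nn" only exists for the duration of the call.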
Distributions in pyro can generally take a tensor of params and return a corresponding tensor of samples, allowing models to be vectorized nicely over batches.
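For example, with `torch.distributions` (which Pyro's distributions build on), a tensor of parameters yields a matching tensor of samples in one call:

```python
import torch
from torch.distributions import Normal

# A batch of K means produces a corresponding batch of K samples.
K = 8
means = torch.randn(K)
samples = Normal(means, torch.ones(K)).sample()  # shape (K,)
```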
For many applications of `random_module()` (e.g. hierarchical models), however, there's currently no corresponding way of vectorizing: you obviously can't sample a single tensor of `k` nns (e.g. one for each domain) and index into them with a batch tensor. @jpchen and I were chatting about a way of handling this by basically extending the sampling method of `random_module()` with a `sample_size` option that internally creates one giant nn with an additional `i`th dimension indexing which of the sampled nns' params you want. This would allow you to simultaneously run a single tensor of data through the corresponding nns. From our conversation: