pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Make subsample_size a parameter of inference #236

Open null-a opened 6 years ago

null-a commented 6 years ago

Since batch_size describes something about how inference should proceed, I'm not sure it makes much sense to specify it in the model. Instead, I think it would make more sense as an argument passed to those algorithms that make use of subsampling.

One immediate practical benefit of this would be that users wouldn't have to consider whether batch_size ought to be specified in the model, the guide, or both.

WebPPL used to be this way, but we changed it to support nested map_data. A better way to handle this in Pyro might be to allow batch sizes to be specified using the names associated with map_data calls. For convenience, it might be useful to also allow a single anonymous batch size to be given, which I guess would apply to the outermost map_data only.
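
For concreteness, here is a purely hypothetical sketch of what that could look like. Neither subsample_size nor subsample_sizes is a real argument of pyro.infer.SVI, and "documents"/"words" are made-up map_data names; the point is only the shape of the proposal: sizes live in the inference call, keyed by map_data name, with an anonymous default for the outermost map_data.

```python
# Hypothetical sketch only: these keyword arguments do not exist in Pyro.
svi = SVI(model, guide, optimizer, loss=Trace_ELBO(),
          subsample_size=64,                  # anonymous default (outermost map_data)
          subsample_sizes={"documents": 32,   # per-map_data sizes, keyed by name
                           "words": 256})
```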

(Possibly interacts with #169.)

fritzo commented 6 years ago

I like the idea of a MinibatchPoutine. It moves us closer to Zinkov & Shan's Composing inference algorithms as probabilistic programs.

ngoodman commented 6 years ago

i think this (making batch_size a parameter of the inference call) is the right approach, because it really is a feature of specific inference algorithms, not models. (one could argue that it's ok to put this info in the guide, since the guide has "stuff for doing inference"; but i'd rather preserve the semantics where the guide is simply a family of normalized distributions, with the same support as the model, representing an importance distribution for the posterior.) by allowing batching functions to be provided to the inference alg, in addition to a simple batch_size, we also solve some issues with control of batching (#169) and possibly data loading?

addressing specific irange / iarange calls with their own batch size might be tricky (are they named?) but having the default be a global batch_size seems like an ok first pass. (if we have to allow local batch_size args in the guides for advanced use, it's probably ok for now....)

note that moving batch_size to the inference alg also makes it easier to use different batch sizes for different (interleaved) optimization passes. i can imagine wanting to do this sometimes.
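
As a rough illustration of the flexibility meant here, a sketch against the current plate/SVI API rather than the interface proposed in this issue: since SVI.step forwards its arguments to both the model and the guide, a subsample size can be made an ordinary argument and varied from one optimization pass to the next. The model, guide, and sizes below are illustrative.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data, subsample_size=None):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    # The subsample size arrives from the inference call rather than being hard-coded.
    with pyro.plate("data", len(data), subsample_size=subsample_size) as ind:
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data[ind])

def guide(data, subsample_size=None):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(q_loc, 1.0))

data = torch.randn(1000)
svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())

# Interleave optimization passes that use different subsample sizes.
for _ in range(100):
    svi.step(data, subsample_size=32)
    svi.step(data, subsample_size=256)
```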

there may be an interaction between this and #154 (which still needs to be considered...).

dustinvtran commented 6 years ago

+1. Both options are useful. For example, in GANs you really might want to treat your program as defining a minibatch, so that forward passes correlate the samples in the way that batch normalization implicitly does, or in the way that Improved GAN's minibatch layer explicitly does.

null-a commented 6 years ago

Note for posterity: The trickiness encountered in deciding how the subsample_size option of iarange should work (#328) feels like it adds to the case for viewing subsampling as an inference thing rather than a model thing. Roughly speaking, subsampling currently happens in the model and/or the guide, and this makes for a fiddly coordination problem. If subsampling were performed by inference, it could just hand minibatches to both the model and the guide, which seems more straightforward conceptually. (Though it may not be straightforward to implement.)
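
To make the coordination concrete, a sketch using the present-day plate API (the successor to iarange), with an illustrative model and guide: for local latents, the plate name, population size, and subsample size have to be repeated in both places, and consistency of the minibatch relies on the guide's subsample indices being reused when the model is replayed against the guide's trace (the behavior discussed in the comments below).

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

N, B = 1000, 64
data = torch.randn(N)

def model(data):
    # The plate name, N, and B here ...
    with pyro.plate("data", N, subsample_size=B) as ind:
        z = pyro.sample("z", dist.Normal(0.0, 1.0).expand([len(ind)]))  # local latents
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data[ind])

def guide(data):
    z_loc = pyro.param("z_loc", torch.zeros(N))
    # ... must be kept in sync with the plate here; the model only sees the same
    # minibatch because the guide's subsample indices are replayed into the
    # identically named plate in the model.
    with pyro.plate("data", N, subsample_size=B) as ind:
        pyro.sample("z", dist.Normal(z_loc[ind], 1.0))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
svi.step(data)
```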

fritzo commented 6 years ago

@null-a I agree and I like the idea of a SubsamplePoutine that adds information in site["infer"]["subsample"] (note that site["infer"] is temporarily renamed to site["baseline"] since we're only using it for baselines right now; I'd prefer to rename that to site["infer"]["baseline"] as our features grow). This seems very clean and easy to implement, and the approach generalizes to things like an EnumeratePoutine that forces enum_discrete at a single site (via site["infer"]["enumerate"]).
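
For what it's worth, a rough sketch of what such a poutine might look like, written against the present-day Messenger API (poutines were later renamed messengers) and using the site["infer"] naming preferred above. The class name and the "subsample" key are illustrative, not an existing Pyro feature.

```python
from pyro.poutine.messenger import Messenger

class SubsampleConfigMessenger(Messenger):
    """Illustrative sketch: annotate named sites with a subsample size by writing
    into site["infer"], leaving it to downstream inference code to act on it."""

    def __init__(self, subsample_sizes):
        super().__init__()
        self.subsample_sizes = subsample_sizes  # e.g. {"data": 64}

    def _process_message(self, msg):
        size = self.subsample_sizes.get(msg["name"])
        if size is not None:
            # Sample sites carry an "infer" dict for inference-specific hints.
            msg.setdefault("infer", {})["subsample"] = size

# Usage sketch: downstream inference code can then read the annotation from a trace.
annotated_model = SubsampleConfigMessenger({"data": 64})(model)
```

If I'm reading the current codebase right, poutine.infer_config follows a similar pattern of updating site["infer"] from a user-supplied config function, so an EnumeratePoutine along the lines above could likewise just write site["infer"]["enumerate"].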

eb8680 commented 6 years ago

Wouldn't it be much simpler to always use the batch_size in the guide and have ReplayPoutine override batch_size in the model? We could add a warning if the batch_size in the model is different.

fritzo commented 6 years ago

@eb8680 What you describe is the current behavior as of #347.