pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai

Consolidating optimization interfaces #109

Closed · eb8680 closed 7 years ago

eb8680 commented 7 years ago

As we build more advanced gradient estimators (#84) and change particular algorithm interfaces to fit our needs (e.g. #74, #32, and #104), we should make sure our abstractions are lightweight and modular enough to limit code and concept duplication without getting in the way of people working independently.

I propose that we move toward a structure more along the lines of optimization in webPPL, where there is a single generic Optimize method parameterized by an objective, an optimization algorithm, a model, and a guide. The method searches for parameter values that minimize the objective and updates them in-place in the parameter store. A sketch of the shape I have in mind is below.
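Roughly, and with all names and signatures purely illustrative rather than actual Pyro API:

```python
# Purely illustrative sketch; names and signatures are hypothetical.
def optimize(model, guide, loss, optim, num_steps, *args, **kwargs):
    """Search for parameter values minimizing loss(model, guide, ...)
    and update them in-place in the parameter store."""
    for _ in range(num_steps):
        # The loss builds its own surrogate objective and calls
        # backward() on it, accumulating gradients into parameters.
        loss.loss_and_grads(model, guide, *args, **kwargs)
        # The optimizer updates whichever parameters received gradients.
        optim.step()
        optim.zero_grad()
```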

Losses

There are many such objectives in the literature and in other PPLs. Two broad categories that have received much attention recently are marginal likelihood estimators (#41), which are increasingly sophisticated special cases of alpha divergences (#35, #91) and f-divergences, and integral probability metrics like MMD and the Wasserstein-1 distance. These objectives are also used to evaluate and criticize models, and ideally the structure of Pyro would reflect this by using the same code for inference and evaluation.
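A minimal sketch of such a shared interface, with hypothetical method names, might look like:

```python
# Hypothetical interface: one loss object serves both inference and
# model criticism.
class Loss:
    def loss(self, model, guide, *args, **kwargs):
        """Return a scalar estimate of the objective (evaluation)."""
        raise NotImplementedError

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """Estimate the objective and accumulate gradients into the
        parameters appearing in it (inference)."""
        raise NotImplementedError
```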

The loss interface should also be able to represent (discounted) Monte Carlo returns from a Gym RL environment so that Pyro can become a tool for RL research. This might require a Pyro wrapper for Gym environments to ensure gradient estimators are correct, but that doesn't seem like a burden for the user. See the sketch below for the quantity such a wrapper would expose.
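Concretely, the loss in question is just the negative discounted return; the snippet below assumes a classic Gym-style env and a policy callable, both of which are illustrative assumptions:

```python
# Illustrative only: a discounted Monte Carlo return phrased as a loss
# to be minimized. `env` follows the classic Gym step/reset API and
# `policy` maps observations to actions; both are assumptions here.
def negative_return(env, policy, gamma=0.99):
    obs, total, discount, done = env.reset(), 0.0, 1.0, False
    while not done:
        obs, reward, done, _ = env.step(policy(obs))
        total += discount * reward
        discount *= gamma
    return -total  # minimizing this maximizes the expected return
```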

Gradients

Our stochastic computation graph++ gradient estimators (currently implemented in #84 only for a particular algorithm) should do all the bookkeeping and heavy lifting needed to minimize stochastic gradient variance under the hood. The gradient estimators should live in the distributions library and be released, along with the primitive distributions and a standard autograd API, as a standalone contribution to the PyTorch ecosystem or even PyTorch core. The primitive distributions themselves should also have fancier gradient estimators built into their samplers, e.g. RSVI estimators for the Gamma distribution.
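To make the "heavy lifting" concrete: for non-reparameterized sample sites, the estimator effectively constructs a score-function (REINFORCE) surrogate like the bare-bones sketch below; the real estimators layer baselines and dependency tracking on top of this.

```python
# Bare-bones score-function (REINFORCE) surrogate; `cost` is a scalar
# tensor and `log_probs` are log-densities of the non-reparameterized
# samples upstream of it. Real estimators add baselines and other
# variance-reduction bookkeeping on top of this.
def surrogate_loss(cost, log_probs):
    # Differentiating the second term yields cost * grad(log p), the
    # score-function term; the first term carries any pathwise gradient.
    return cost + cost.detach() * sum(log_probs)
```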

Optimizers

There are also a number of optimization algorithms beyond simple first-order stochastic gradient methods like SGD and Adam. For example, although they're often impractical for deep networks because of storage constraints, SVRG/SAGA-type algorithms may be quite helpful for some probabilistic models, especially with high-variance gradient estimators. Another recent algorithm with especially promising experimental results is stochastic trust region optimization.

It seems reasonable to expect that these algorithms can all be used through the same uniform interface and be designed to be agnostic to the gradient estimator used. Pyro will also eventually contain (stochastic) gradient-based MCMC algorithms, and these should share almost all of their code with the first-order optimizers (e.g. by simply calling an additional add_noise method). Is the current pyro.optim interface able to accommodate these future needs?
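For instance, under a hypothetical interface along these lines, an SGLD sampler is just SGD plus one extra call:

```python
import torch

# Hypothetical optimizer interface showing how SGLD-type samplers could
# reuse first-order optimizer code via a single add_noise method.
class SGD:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr

    def step(self):
        for p in self.params:
            if p.grad is not None:
                p.data.add_(-self.lr * p.grad)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

    def add_noise(self):
        # Gaussian noise with variance 2 * lr turns the SGD update into
        # stochastic gradient Langevin dynamics.
        for p in self.params:
            p.data.add_(torch.randn_like(p) * (2 * self.lr) ** 0.5)
```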

Active and inactive parameters

To accommodate optimizing fancy losses without having to expose all intermediate traces (e.g. the importance-weighted ELBO), we could also update the parameter store so that every time backward is called, the parameters that appear in the graph register themselves as active in the parameter store. The Optimize function could then query the parameter store, hand the parameters and their gradients to the optimization algorithm, and declare them inactive after the algorithm has updated them. This should happen at the level of Optimize so that the algorithms themselves don't depend on the parameter store and can be reused for HMC- and SGLD-type algorithms.
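A minimal sketch of the registration mechanism, assuming a hypothetical ParamStore, could use tensor gradient hooks:

```python
import torch

# Minimal sketch, assuming a hypothetical ParamStore; parameters mark
# themselves active when a gradient reaches them during backward().
class ParamStore:
    def __init__(self):
        self._params, self._active = {}, set()

    def get_param(self, name, init):
        if name not in self._params:
            p = torch.tensor(init, requires_grad=True)
            # The hook fires whenever backward() computes p's gradient.
            def mark_active(grad, name=name):
                self._active.add(name)
                return grad
            p.register_hook(mark_active)
            self._params[name] = p
        return self._params[name]

    def pop_active(self):
        """Return parameters touched since the last call and reset
        them to inactive."""
        active = [self._params[n] for n in sorted(self._active)]
        self._active.clear()
        return active
```

Optimize would then call pop_active() after each backward() and hand exactly those parameters to the algorithm, which stays ignorant of the parameter store.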

jpchen commented 7 years ago

+1. I'm in support of a generic Optimize that's agnostic to the estimator, objective, etc., especially as we add more of them (and better ones are discovered!). I've thought about this in the past; I can add a code snippet next week of what I think it would look like and how it would interact. This can also help with refactoring the gradient calculation and the stepping (#32) during optimization.

eb8680 commented 7 years ago

v0 implemented in #212; closing this in favor of more targeted discussions going forward.