pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Refactor/Upgrade `find_MAP` #7308

Open jessegrabowski opened 1 month ago

jessegrabowski commented 1 month ago

Description

find_MAP is a bit out of style, but it's still useful functionality in PyMC that could use some love. In particular, I'm thinking about the following upgrades:

  1. Allow freer access to the underlying scipy.optimize.minimize configuration

Currently, the user can't freely specify whether gradients are used -- the choice is made automatically based on their availability. This is a fine heuristic, but also needless. More importantly, if gradients are not available, the method requested by the user is completely overridden to powell. The method should remain a free choice, and find_MAP should raise an error if the user asks for a gradient-based solver when no gradients are available.
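
A minimal sketch of that behaviour using only scipy.optimize.minimize (the helper name and the set of gradient-free methods are illustrative, not an existing PyMC API):

```python
import numpy as np
from scipy.optimize import minimize

# Illustrative subset of scipy methods that don't need gradients.
GRADIENT_FREE = {"nelder-mead", "powell", "cobyla"}

def minimize_with_explicit_choice(neg_logp, x0, method, dneg_logp=None):
    # Respect the requested method; error instead of silently falling back to powell.
    if dneg_logp is None and method.lower() not in GRADIENT_FREE:
        raise ValueError(f"Method {method!r} needs gradients, but none are available.")
    return minimize(neg_logp, x0, jac=dneg_logp, method=method)

# A gradient-free request is honoured as-is.
res = minimize_with_explicit_choice(lambda x: np.sum(x**2), np.ones(3), "powell")
print(res.x)
```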

  2. Allow second-derivative information

All models can compile a d2logp function, which can be exploited by several scipy.optimize.minimize routines, including Newton-CG, trust-ncg, trust-krylov, and trust-exact. There's no reason not to (optionally) allow users to bring in this information. For small problems the current d2logp function (which computes the full dense Hessian) should be fine. For performance reasons, though, we should also consider compiling a hessp function that returns the Hessian-vector product at a given vector. This is all that is needed by Newton-CG, trust-ncg, and trust-krylov (only trust-exact requires the full dense Hessian).
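
A sketch of how a hessp could be built and handed to scipy, using the standard double-differentiation trick in PyTensor (the toy objective stands in for a compiled negative logp):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy.optimize import minimize

x = pt.dvector("x")
v = pt.dvector("v")
neg_logp = pt.sum(pt.exp(x) - x)  # toy stand-in for the model's -logp

grad = pt.grad(neg_logp, x)
# grad(grad . v) with respect to x gives H @ v without forming the dense Hessian.
hessp = pt.grad(pt.dot(grad, v), x)

f = pytensor.function([x], neg_logp)
jac = pytensor.function([x], grad)
hvp = pytensor.function([x, v], hessp)

x0 = np.random.normal(size=3)
res = minimize(f, x0, jac=jac, hessp=hvp, method="trust-ncg")
print(res.x)  # minimum at x = 0
```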

  3. Allow compilation to alternate backends

The following code almost works, but not quite:

```python
with pytensor.config.change_flags(mode="JAX"):
    res = pm.find_MAP()
```

This is obviously desirable in many cases (Scan-based models, looking at you). But it would also be useful in others, for example if we just really want to run our model on a GPU. I suggest an API like pm.find_MAP(backend='JAX') that handles all of this for the user.
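
For reference, the mechanism a backend keyword could lean on already exists at the pytensor.function level (the snippet assumes JAX is installed; the graph is a toy stand-in for the model's logp and gradient):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
neg_logp = 0.5 * pt.sum(x**2)
grad = pt.grad(neg_logp, x)

# mode="JAX" routes compilation through the JAX linker instead of the C backend.
f_and_g = pytensor.function([x], [neg_logp, grad], mode="JAX")
print(f_and_g(np.ones(3)))
```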

  4. Allow minibatch support with stochastic optimization

Admittedly more of a stretch goal, but for memory-bound models it would be nice if we could get interoperability with the minibatch + stochastic optimization framework that already exists within PyMC. There's basically nothing standing in the way of this as far as I can tell. find_MAP already has a method argument; this could take either a string (accessing a scipy optimizer) or a PyMC stochastic optimizer (like pm.adam), which would trigger the appropriate machinery.
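
A sketch of what the dispatch could do when method is a PyMC optimizer, assuming pm.adam follows its usual signature of taking a loss and a list of shared variables and returning an update dict (the objective here is a stand-in for a minibatch negative logp):

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pytensor.shared(np.zeros(3), name="x")
loss = pt.sum((x - 2.0) ** 2)            # stand-in for a minibatch -logp
updates = pm.adam(loss, [x], learning_rate=0.1)

# One compiled function per optimization step, as in the VI machinery.
step = pytensor.function([], loss, updates=updates)
for _ in range(500):
    step()
print(x.get_value())  # approaches [2., 2., 2.]
```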

  5. Allow access to scipy.optimize.basinhopping

This is at the bottom because probably nobody but me cares, but I find simulated annealing to be quite robust to complex problems, and a good nuclear bomb to throw at really tricky optimization problems. It's a separate API from scipy.optimize.minimize, but we could hook it up in the background via the method argument. This is what statsmodels does, for example.
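
Wiring this up would mostly be a dispatch question, since basinhopping just wraps a local minimizer; a sketch with a toy multimodal objective standing in for the compiled negative logp:

```python
import numpy as np
from scipy.optimize import basinhopping

def neg_logp(x):
    return np.sum((x**2 - 1.0) ** 2)       # toy multimodal objective

def dneg_logp(x):
    return 4.0 * x * (x**2 - 1.0)

x0 = np.full(2, 3.0)
# The compiled logp/dlogp would be passed to the inner minimizer exactly as
# they are for scipy.optimize.minimize today.
res = basinhopping(
    neg_logp,
    x0,
    niter=50,
    minimizer_kwargs={"method": "L-BFGS-B", "jac": dneg_logp},
)
print(res.x)  # near one of the modes at +/-1
```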

ricardoV94 commented 1 month ago

Anything that makes it more useful sounds good on my end.