pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Conjugate relationship marginalisation #358

Open theorashid opened 2 months ago

theorashid commented 2 months ago

cc: @larryshamalama

Implement conjugate relationships in PyMC via rewrites. Saves us working them out by hand.

For starters:

nimble has a list of possibilities that we can add.

ricardoV94 commented 2 months ago

What is the idea of conjugacy? I'm familiar with conjugate priors, but those are not the same as marginalization? Instead they provide closed form solutions for posteriors?

Triple ? means I may be missing the point :)

jessegrabowski commented 2 months ago

This issue made me think of this paper: https://arxiv.org/abs/2302.00564

But maybe this is thinking about fully conjugate models, not intermediate relationships?

theorashid commented 2 months ago

I was thinking we would follow nimble's conjugate (Gibbs) samplers (probably a better name for the issue). So any prior with a sampling (dependent) node (their language) which is conjugate can be rewritten.

But also we could have rewrites to take advantage of the properties of normals or exponential distributions. e.g. (borrowed from an old chat with numpyro devs)

PS. Since I appreciate that not many people around here have used nimble, here's the default MCMC config it builds for a simple BUGS model.

pumpCode <- nimbleCode({ 
  # Define relationships between nodes
  for (i in 1:N){
      theta[i] ~ dgamma(alpha,beta)
      lambda[i] <- theta[i]*t[i]
      x[i] ~ dpois(lambda[i])
  }
  # Set priors
  alpha ~ dexp(1.0)
  beta ~ dgamma(0.1,1.0)
})
...
pumpMCMC <- buildMCMC(pumpModel)
## ===== Monitors =====
## thin = 1: alpha, beta
## ===== Samplers =====
## RW sampler (1)
##   - alpha
## conjugate sampler (11)
##   - beta
##   - theta[]  (10 elements)

So it uses conjugate samplers wherever a conjugate relationship exists (theta: gamma -> poisson, beta: gamma -> gamma), and samples everything else (here alpha) with the default non-conjugate RW (random walk) sampler.
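To make the two conjugate relationships concrete, here is a minimal sketch of the closed-form updates a conjugate sampler would draw from for the pump model. It assumes nimble's shape/rate parameterisation of dgamma; the function names and the data values are made up for illustration, and this is plain Python rather than anything from nimble or PyMC.

```python
import random


def theta_posterior_params(alpha, beta, x_i, t_i):
    """Gamma (shape, rate) of theta[i] | x[i], for x[i] ~ Poisson(theta[i] * t[i])
    with prior theta[i] ~ Gamma(alpha, rate=beta)."""
    return alpha + x_i, beta + t_i


def beta_posterior_params(alpha, thetas, prior_shape=0.1, prior_rate=1.0):
    """Gamma (shape, rate) of beta | theta, for theta[i] ~ Gamma(alpha, rate=beta)
    with prior beta ~ Gamma(prior_shape, rate=prior_rate)."""
    return prior_shape + len(thetas) * alpha, prior_rate + sum(thetas)


def draw_gamma(rng, shape, rate):
    # random.gammavariate takes (shape, scale), so the rate is inverted.
    return rng.gammavariate(shape, 1.0 / rate)


def conjugate_sweep(rng, alpha, beta, x, t):
    """One Gibbs sweep over the conjugate nodes; alpha is held fixed here
    because it has no conjugate update and would need an RW step instead."""
    thetas = [
        draw_gamma(rng, *theta_posterior_params(alpha, beta, x_i, t_i))
        for x_i, t_i in zip(x, t)
    ]
    new_beta = draw_gamma(rng, *beta_posterior_params(alpha, thetas))
    return thetas, new_beta


# Illustrative values only, not the real pump data.
rng = random.Random(0)
thetas, beta = conjugate_sweep(rng, alpha=1.0, beta=1.0, x=[5, 1, 5], t=[94.3, 15.7, 62.9])
```

Each draw is exact given the current values of the other nodes, which is why no tuning or accept/reject step is needed for theta and beta.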

ricardoV94 commented 2 months ago

Okay, that clarifies it. I just hadn't heard of conjugacy as marginalization. I am not sure how we should handle this. My best guess, when thinking about this in the past, was to define a conjugate step sampler that can take draws from an (arbitrary) posterior distribution that has a closed-form solution.

Perhaps the easiest is a helper find_conjugate_steps that would return the specialized step samplers and could be passed to pm.sample. API would look something like:

with pm.Model() as m:
  ...
  conjugate_steps = pmx.find_conjugate_steps()
  idata = pm.sample(step=conjugate_steps)

That would have a natural fallback when conjugate steps can't be found, and users can also exclude some if they don't like it.
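As a rough illustration of the fallback and exclusion behaviour described above, here is a toy sketch of what such a helper might do. Everything in it is hypothetical: the dict-based model description, the `CONJUGATE_PAIRS` table, and `find_conjugate_steps_toy` are invented for this example and are not part of PyMC or pymc-experimental. A real implementation would inspect the model graph and check which parameter of the dependent distribution the prior enters.

```python
# Hypothetical table of (prior family, dependent-node family) pairs treated as
# conjugate. Only the two relationships from the nimble example are listed.
CONJUGATE_PAIRS = {
    ("gamma", "poisson"),
    ("gamma", "gamma"),
}


def find_conjugate_steps_toy(model, exclude=()):
    """Split variables into those eligible for a conjugate step and the rest.

    `model` maps a variable name to (prior family, list of dependent-node
    families). Variables in `exclude` always go to the fallback list.
    """
    conjugate, fallback = [], []
    for name, (prior, dependents) in model.items():
        is_conjugate = bool(dependents) and all(
            (prior, dep) in CONJUGATE_PAIRS for dep in dependents
        )
        if is_conjugate and name not in exclude:
            conjugate.append(name)
        else:
            fallback.append(name)
    return conjugate, fallback


# Toy description of the pump model from the nimble example.
pump = {
    "alpha": ("exponential", ["gamma"]),
    "beta": ("gamma", ["gamma"]),
    "theta": ("gamma", ["poisson"]),
}

conjugate, fallback = find_conjugate_steps_toy(pump)
# conjugate == ["beta", "theta"], fallback == ["alpha"]
```

Variables left in the fallback list would simply be handed to whatever sampler the existing assignment logic picks, which is the "natural fallback" mentioned above.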

Otherwise we would need to rework the step sampler assignment logic that exists in PyMC. That logic eagerly assigns a variable to a sampler if the variable is of a certain type (or if the model logp can be differentiated wrt it), and doesn't really have a nice place for reasoning about the whole model (I could be wrong here).