Gathering my thoughts on how to implement nonparametrics in mcx. This needs to be thought through before freezing the internal API, as it is likely to impact it substantially. Implementation can wait, but the API design must be ready before the first release.
Here is a purely speculative design:
    import jax
    import mcx

    @mcx.model
    def DirichletProcess(alpha, base_measure):
        # Caches memoize sticks and atoms so that each index is drawn only once.
        sticks_cache = []
        atoms_cache = []

        @mcx.model
        def sticks(k):
            @mcx.model
            def new_stick():
                b <~ Beta(1, alpha)
                sticks_cache.append(b)
                return b

            # Reuse the k-th stick if it exists, otherwise break a new one.
            return jax.lax.cond(
                k < len(sticks_cache),
                lambda _: sticks_cache[k],
                lambda _: new_stick(),
                None
            )

        @mcx.model
        def atoms(k):
            @mcx.model
            def new_atom():
                a <~ base_measure
                atoms_cache.append(a)
                return a

            # Reuse the k-th atom if it exists, otherwise draw a new one.
            return jax.lax.cond(
                k < len(atoms_cache),
                lambda _: atoms_cache[k],
                lambda _: new_atom(),
                None
            )

        @mcx.model
        def distribution():
            # Walk down the sticks and return the atom of the first stick picked.
            k = 0
            while True:
                stick <~ sticks(k)
                do_pick <~ Bernoulli(stick)
                if do_pick:
                    return atoms(k)
                k += 1

        return distribution
This is nice, but to see how this would impact MCX's design we need to answer the following questions:
What does it look like to sample from distribution?
What does it look like to sample from DirichletProcess?
What does the logpdf of distribution look like?
What does the logpdf of DirichletProcess look like?
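To make these questions a bit more concrete, here is a minimal sketch of the same lazy stick-breaking construction in plain Python, using only the standard library (no MCX, no JAX). Everything in it is an assumption made for illustration: `LazyDirichletProcess`, `_extend_to` and `log_weight` are hypothetical names, and the base measure is assumed to be a callable that takes a random number generator. Building the object plays the role of sampling from DirichletProcess, `sample` plays the role of sampling from distribution, and `log_weight` only gives the log-probability of picking index `k` conditional on the sticks instantiated so far.

```python
import math
import random


class LazyDirichletProcess:
    """One lazily built realization of DP(alpha, base_measure).

    Hypothetical helper for illustration only; not part of MCX.
    """

    def __init__(self, alpha, base_measure, rng):
        self.alpha = alpha
        self.base_measure = base_measure  # assumed: callable rng -> atom
        self.rng = rng
        self.sticks = []  # memoized Beta(1, alpha) draws
        self.atoms = []   # memoized base-measure draws

    def _extend_to(self, k):
        # Instantiate sticks and atoms up to index k, reusing cached values.
        while len(self.sticks) <= k:
            self.sticks.append(self.rng.betavariate(1.0, self.alpha))
            self.atoms.append(self.base_measure(self.rng))

    def sample(self):
        # Walk down the sticks; stop at the first successful Bernoulli(stick).
        k = 0
        while True:
            self._extend_to(k)
            if self.rng.random() < self.sticks[k]:
                return k, self.atoms[k]
            k += 1

    def log_weight(self, k):
        # log P(index = k | sticks) = log b_k + sum_{j<k} log(1 - b_j).
        # Note: asking for an index not seen yet instantiates new sticks.
        self._extend_to(k)
        return math.log(self.sticks[k]) + sum(
            math.log1p(-b) for b in self.sticks[:k]
        )


rng = random.Random(0)
dp = LazyDirichletProcess(
    alpha=2.0,
    base_measure=lambda r: r.gauss(0.0, 1.0),
    rng=rng,
)
draws = [dp.sample() for _ in range(1000)]
```

The sketch sidesteps the hard part: the caches are mutated as a side effect of sampling, which is exactly what a traced, functional implementation would have to express differently. It also says nothing about what a logpdf over DirichletProcess itself should be.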
References
General reference on probabilistic programming and what a language would need to be able to implement non-parametrics (look for "HOPPL"):
General literature on nonparametrics:
Stochastic memoization:
Collection of samplers for infinite mixture models: https://github.com/tscholak/imm
Truncated Dirichlet Process in PyMC3: