rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
324 stars 17 forks

Nonparametrics #17

Closed rlouf closed 3 years ago

rlouf commented 4 years ago

DRAFT

Gathering my thoughts on how to implement nonparametrics in mcx. This needs to be thought through before freezing the internal API, as it is likely to impact it substantially. Implementation can wait, but the API design must be ready before the first release.

Here is a purely speculative design:

@mcx.model
def DirichletProcess(alpha, base_measure):

    sticks_cache = []
    atoms_cache = []

    @mcx.model
    def sticks(k):

        @mcx.model
        def new_stick():
            b <~ Beta(1, alpha)
            sticks_cache.append(b)
            return b

        return jax.lax.cond(
            k < len(sticks_cache),
            lambda _: sticks_cache[k],
            lambda _: new_stick(),
            None
        )

    @mcx.model
    def atoms(k):

        @mcx.model
        def new_atom():
            a <~ base_measure
            atoms_cache.append(a)
            return a

        return jax.lax.cond(
            k < len(atoms_cache),
            lambda _: atoms_cache[k],
            lambda _: new_atom(),
            None
        )

    @mcx.model
    def distribution():
        k = 0
        while True:
            stick <~ sticks(k)
            do_pick <~ Bernoulli(stick)
            if do_pick:
                return atoms(k)
            k += 1

    return distribution

This is nice, but to see how this would impact MCX's design we need to answer the following questions:

  1. What does it look like to sample from distribution?
  2. What does it look like to sample from DirichletProcess?
  3. What does the logpdf of distribution look like?
  4. What does the logpdf of DirichletProcess look like?
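
To make questions 1 and 2 a bit more concrete, here is a rough sketch of the same lazy stick-breaking scheme in plain Python, outside of any MCX machinery. The class `LazyDirichletProcess` and the `base_sampler` argument are hypothetical names made up for illustration, not part of MCX. Constructing the object plays the role of sampling from DirichletProcess (it fixes one random measure), and each call to `sample` plays the role of sampling from distribution. It relies on Python control flow and growing lists, so it would not trace under `jax.jit` as written, which is exactly the difficulty the design above has to resolve.

```python
import jax


class LazyDirichletProcess:
    """One random measure G ~ DP(alpha, base); sticks and atoms are created on demand."""

    def __init__(self, key, alpha, base_sampler):
        self.key = key
        self.alpha = alpha
        self.base_sampler = base_sampler  # hypothetical signature: key -> atom
        self.sticks = []  # memoized Beta(1, alpha) stick fractions
        self.atoms = []   # memoized draws from the base measure

    def _next_key(self):
        # Split off a fresh subkey so that every random draw uses its own key.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def _extend(self, k):
        # Create sticks and atoms up to index k, reusing the cached ones.
        while len(self.sticks) <= k:
            self.sticks.append(jax.random.beta(self._next_key(), 1.0, self.alpha))
            self.atoms.append(self.base_sampler(self._next_key()))

    def sample(self):
        # Walk the sticks in order and stop at index k with probability sticks[k].
        k = 0
        while True:
            self._extend(k)
            if jax.random.bernoulli(self._next_key(), self.sticks[k]):
                return self.atoms[k]
            k += 1


# Assumed standard normal base measure, purely for illustration.
dp = LazyDirichletProcess(jax.random.PRNGKey(0), alpha=1.0, base_sampler=jax.random.normal)
draws = [float(dp.sample()) for _ in range(20)]
```

Because the caches persist across calls, repeated draws return previously seen atoms with positive probability, which is the stochastic memoization behaviour the design relies on.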

References

General reference on probabilistic programming and what a language would need to be able to implement non-parametrics (look for "HOPPL"):

General literature on nonparametrics:

Stochastic memoization:

Collection of samplers for infinite mixture models: https://github.com/tscholak/imm

Truncated Dirichlet Process in PyMC3:
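
For comparison with the lazy design, a truncated approximation sidesteps the dynamic caches by fixing the number of sticks in advance, which also makes the log-density of the sticks straightforward to write down. The following is only a sketch under stated assumptions: the function names are hypothetical, the base measure is assumed to be a standard normal, and the truncation level `num_sticks` is chosen arbitrarily. It is not how MCX or PyMC3 implement it.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta


def stick_breaking_weights(betas):
    """Map stick fractions b_k to weights w_k = b_k * prod_{j<k} (1 - b_j)."""
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - betas)[:-1]])
    return betas * remaining


def sample_truncated_dp(key, alpha, num_sticks):
    """Sample stick fractions, atoms and weights for a truncation at num_sticks."""
    key_sticks, key_atoms = jax.random.split(key)
    betas = jax.random.beta(key_sticks, 1.0, alpha, shape=(num_sticks,))
    betas = betas.at[-1].set(1.0)  # the last stick takes all the leftover mass
    atoms = jax.random.normal(key_atoms, (num_sticks,))  # assumed N(0, 1) base measure
    return betas, atoms, stick_breaking_weights(betas)


def sticks_logpdf(betas, alpha):
    """Joint log-density of the free stick fractions under Beta(1, alpha)."""
    # The last fraction is fixed to 1 by the truncation, so it is excluded.
    return jnp.sum(beta.logpdf(betas[:-1], 1.0, alpha))


# The truncation level must be static so that array shapes are known at trace time.
betas, atoms, weights = jax.jit(sample_truncated_dp, static_argnums=2)(
    jax.random.PRNGKey(0), 2.0, 50
)
# Sampling from the (truncated) random measure is a categorical draw over the atoms.
draws = jax.random.choice(jax.random.PRNGKey(1), atoms, shape=(10,), p=weights)
```

With fixed shapes everything compiles under `jax.jit`, at the cost of an approximation error that depends on `alpha` and the truncation level.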

rlouf commented 3 years ago

Discussion moved to #94