probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Conditioning on deterministic functions of random choices #134

Closed dpmerrell closed 4 years ago

dpmerrell commented 4 years ago

I want to know if there is a "right way" to condition my model on observations which are deterministic functions of random choices.

For illustration, think of a simple Binomial model:

@gen function binomial_model()
    N = @trace(geometric(0.05), :N)            # number of coin flips
    p = @trace(beta(2,2), :p)                  # success probability
    xs = zeros(N)
    for i=1:N
        xs[i] = @trace(bernoulli(p), :x => i)  # individual flips
    end
    z = sum(xs)   # deterministic function of the flips -- not traced
    return z
end

Now, suppose we observe z and want to infer posteriors for N and p. (We don't really care about the actual random choices -- :x => i.)

As the model is currently written, that inference is impossible -- and for a simple reason: z is not traced, and therefore will not have an address in the choicemap.

Of course, there are ways to work around this. I can implement a PointMass distribution and let z = @trace(pointmass(sum(xs)), :z). But if I use a naive sampling strategy (e.g., Gen's default generate method), then the vast majority of traces will have zero probability. (And if z were continuous, the chances of getting a trace with nonzero probability would vanish.)
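
For concreteness, such a PointMass distribution might look roughly like this (a sketch only; PointMass is hypothetical, and the exact set of Distribution interface methods required varies by Gen version):

using Gen
import Gen: random, logpdf, logpdf_grad, has_output_grad, has_argument_grads, is_discrete

# Hypothetical distribution placing all of its mass on its single argument.
struct PointMass <: Gen.Distribution{Float64} end
const pointmass = PointMass()

random(::PointMass, val::Real) = Float64(val)
logpdf(::PointMass, x::Real, val::Real) = x == val ? 0.0 : -Inf
logpdf_grad(::PointMass, x::Real, val::Real) = (nothing, nothing)
has_output_grad(::PointMass) = false
has_argument_grads(::PointMass) = (false,)
is_discrete(::PointMass) = true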

This, of course, could be remedied by using a smarter sampling strategy that respects the constraints imposed by the observation.

Anyway, all of the solutions I come up with seem kind of hacky. I'm wondering if there's some language feature I'm not aware of that handles this.

PS: It seems like, in the general case, Gen's generate function would require an internal constraint solver to handle these kinds of observations. So I'm guessing this kind of inference is just fundamentally difficult to do in an automated fashion.

PPS: I'm not sure whether raising an issue is the right way to ask these questions -- is there a forum somewhere? Thanks for your time :)

alex-lew commented 4 years ago

Hey @dpmerrell,

Great question. A few responses:

  1. Some kinds of deterministic functions are easier to handle than others. I've been experimenting with a DSL for 'distributions' (#103) that would allow you to write things like @dist shifted_poisson(min_value, mean_value) = poisson(mean_value - min_value) + min_value to define distributions that are deterministic transformations of other distributions. In a model, you could then write something like @trace(shifted_poisson(5, 10), :number_of_students) and later condition on the number of students (which is a deterministic function of an underlying Poisson random variable). But this only applies in limited cases.

  2. Hakaru (https://hakaru-dev.github.io) is designed to handle the sorts of problems you're talking about. Gen treats probabilistic programs as defining distributions over traces of random choices, but Hakaru treats them as defining distributions over return values (which may be deterministic functions of random choices). If your program returns a tuple (a, b), you can condition on a and get a distribution over b. Hakaru is basically automating the process of finding an implementation of Gen.generate. But as you note, this is not tractable in general, as it can require solving arbitrary constraints. As such, for many programs, Hakaru will simply report failure. The programming language constructs Hakaru supports are also more limited, so as to make the program analyses it carries out feasible. That said, the Hakaru team has been working on growing the language, and it now supports a limited class of programs with loops and arrays, which might make it suitable for expressing the "sum of coin flips" program you've written.

  3. Another approach to handling this issue is the one taken by Omega (https://github.com/zenna/Omega.jl or http://www.zenna.org/publications/icmlsoft.pdf), which is to relax the constraints in a systematic way so as to make sampling possible.

  4. Yet another approach is probabilistic logic programming, which can take advantage of Prolog-style constraint-solving during inference -- see, e.g., ProbLog (https://dtai.cs.kuleuven.be/problog/).

  5. Finally, Gen lets you customize the implementation of generate and other methods on a per-model basis. So one approach would be to use the PointMass distribution as you propose, but add the constraint-solving logic yourself. A custom generate for the sum-of-bernoullis model could, e.g., sample N' ~ geometric(0.5), set N = z + N', and randomly choose z of the Bernoulli flips to set to true (and leave the others false) -- see the sketch below. Note that generate is not responsible for sampling from the posterior, only for having some way to sample values that are consistent with the constraints. Once a custom implementation of generate (and/or other methods) is defined, Gen's inference library (importance sampling, MH, etc.) can be used on the model to obtain (approximate) posterior samples. Furthermore, models that call your customized generative function will automatically inherit its logic, because the default generate is implemented compositionally (if f calls g, then generate(f, ...) calls generate(g, ...)).
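
For the sum-of-bernoullis model, that constraint-solving step might look like the following (a sketch only: consistent_constraints is an illustrative helper, not a Gen API, and a full custom generate would also need to compute the appropriate importance weight):

using Gen, Random

# Illustrative helper: build a choicemap guaranteed to satisfy sum(xs) == z.
function consistent_constraints(z::Int)
    n_extra = geometric(0.5)        # N' ~ geometric(0.5); Gen's geometric has support {0, 1, 2, ...}
    n = z + n_extra                 # N = z + N', so at least z flips exist
    constraints = choicemap((:N, n))
    for i in 1:n
        constraints[:x => i] = false
    end
    for i in randperm(n)[1:z]       # set exactly z randomly chosen flips to true
        constraints[:x => i] = true
    end
    return constraints
end

These constraints (together with the :z observation, for the PointMass-augmented model) can then seed Gen's default generate, which samples the remaining choice :p from its prior and returns a weight.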

Admittedly, the user experience of customizing generate and other methods is somewhat clunky right now -- we're working on making that nicer. I think it'd also be an interesting research direction to see if the techniques developed by the Hakaru team could be used to extend the class of Gen models for which reasonable generate implementations can be synthesized automatically, to include some models with deterministic dependencies between traced variables.

Hope that helps!

marcoct commented 4 years ago

Hi @dpmerrell. I'll add one thing to what Alex said:

In Gen, you can currently implement an approach similar to the one used in (3.): introduce some "noise" into the sum, constrain the noisy version of the sum, and then anneal the noise down from a large value to a small value during, e.g., MCMC. This won't represent the exact posterior you were looking for, but it may give close enough results, depending on your aims.

The modified model would be:

@gen function binomial_model(noise)
    N = @trace(geometric(0.05), :N)
    p = @trace(beta(2,2), :p)
    xs = zeros(N)
    for i=1:N
        xs[i] = @trace(bernoulli(p), :x => i)
    end
    z = sum(xs)
    noisy_z = @trace(normal(z, noise), :noisy_z)   # traced, noisy version of the sum -- this is what we constrain
    return z
end

A minimal inference program in this approach would be something like:

initial_noise = 10.
final_noise = 0.001
num_iters = 1000
# exponentially decay the noise from initial_noise toward final_noise
noise_schedule = final_noise .+ (initial_noise - final_noise) * exp.(-0.05 * (0:num_iters))
@assert noise_schedule[1] == initial_noise
@assert noise_schedule[end] ≈ final_noise   # the exponential term is negligible by the end

using Gen   # for mh, select, generate, choicemap, update, UnknownChange, get_retval

function do_mcmc_moves(trace)
    trace, = mh(trace, select(:N, :x))   # block move: resimulate N and all of the flips
    trace, = mh(trace, select(:x))       # resimulate all of the flips
    for i=1:trace[:N]-1
        trace, = mh(trace, select(:x => i, :x => (i+1))) # might reduce one and increase another
    end
    trace, = mh(trace, select(:p))       # resimulate p
    return trace
end

function do_inference(observed_z)

    trace, = Gen.generate(binomial_model, (initial_noise,), choicemap((:noisy_z, observed_z)))
    for iter=1:num_iters

        # do MCMC moves on various random choices, except for :noisy_z
        trace = do_mcmc_moves(trace) 

        # reduce the noise a bit
        trace, = update(trace, (noise_schedule[iter],), (UnknownChange(),), choicemap())

        z = get_retval(trace)
        println(z)
    end

    return trace
end
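
A minimal invocation, matching the small test mentioned further down (observed_z = 10), would just be:

trace = do_inference(10)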

dpmerrell commented 4 years ago

@alex-lew Very helpful -- thank you!

For my own purposes, I'll look into customizing generate. And I'll most likely take the PointMass / "constrained MCMC" approach.

I agree it would be cool if some of Hakaru's constraint-solving/trace generation capabilities were ported into Gen. I'll comment in this issue if I have any bright ideas on that front.

@marcoct that's fascinating, I haven't been exposed to that class of inference methods before. I guess I'll have to read that paper Alex linked to.

marcoct commented 4 years ago

I tested the code I posted above on a small example (observed_z = 10), and it meets the constraint z = 10 within about 50-60 iterations, and maintains that constraint for the rest of the Markov chain.

More sophisticated choices of moves inside do_mcmc_moves, in particular using the 4-argument variant of https://probcomp.github.io/Gen/dev/ref/inference/#Gen.metropolis_hastings, would greatly improve the mixing of the Markov chain beyond the vanilla MCMC that I put in there.
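
As one example of such a move, here is a conjugate proposal for :p (a sketch using the custom-proposal variant of mh; the 4-argument involution variant would additionally allow joint moves on :N and the flips):

# Sketch: propose :p from its exact conditional given the current flips
# (beta-bernoulli conjugacy), so this mh move is always accepted.
@gen function p_proposal(trace)
    n = trace[:N]
    n_true = n == 0 ? 0 : sum(trace[:x => i] for i in 1:n)
    @trace(beta(2 + n_true, 2 + n - n_true), :p)
end

trace, accepted = mh(trace, p_proposal, ())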

marcoct commented 4 years ago

I'm closing this because conditioning on the return value of arbitrary deterministic functions is, in general, an arbitrarily expensive operation, so allowing arbitrary deterministic expressions to be constrained by generate is unlikely ever to be supported in full generality in Gen. This ensures that the running time of an inference program remains predictable from the running time of the generative function itself (for at least the built-in modeling languages, generate should have about the same running time as simulate).

The most likely near-term feature enhancement that will address this issue is https://github.com/probcomp/Gen/pull/103, which makes it possible to construct new distributions by combining existing distributions with deterministic code. The resulting values can then be treated as random choices just like those from other distributions, and they can be constrained in generate.
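
For example, with the DSL from that PR, the shifted Poisson from Alex's first point could be defined and then constrained directly (a sketch assuming the @dist syntax proposed in #103; classroom_model is an illustrative name):

using Gen

# A distribution defined as a deterministic affine transform of a Poisson.
@dist shifted_poisson(min_value, mean_value) = poisson(mean_value - min_value) + min_value

@gen function classroom_model()
    n = @trace(shifted_poisson(5, 10), :number_of_students)
    return n
end

# Because :number_of_students is an ordinary random choice with its own logpdf,
# it can be constrained in generate like any other choice:
trace, weight = generate(classroom_model, (), choicemap((:number_of_students, 12)))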