probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Integration of trace updates and parameter updates #136

Open marcoct opened 4 years ago

marcoct commented 4 years ago

Generative functions may have state, which is their trainable parameters, and any gradient accumulation state for those parameters. This state is not managed by Gen.

This has not been a major issue so far, because our use cases have not combined parameter updates with random choice updates: parameter updates have mainly been used to train proposals, and random choice updates have been used with generative models, e.g. for MCMC. However, algorithms like EM might use both.

The behavior of update and regenerate when the parameters have changed since the trace was created is not currently defined in the Gen documentation. Options include:

More thought is needed here, in particular regarding algorithms like EM that may conceivably use both update/regenerate and trainable parameters.

The current implementations of update and regenerate for the built-in modeling languages don't check whether the parameters have changed, so the returned weight may be incorrect (assuming that ignoring parameter changes is not the desired behavior, which, as mentioned above, I don't think it should be).

alex-lew commented 4 years ago

Note that these methods should not return any ratio of weights involving new and old parameters: the weight is always computed under a single setting of the parameters.
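For concreteness, one hedged way to write this down, using t for the choices, x for the arguments, and θ for the parameters (this notation is an assumption, and the formula only covers the simple case where update changes or constrains existing choices and samples no new ones): both densities in the weight are evaluated at the same, current θ.

```latex
w = \log p(t'; x', \theta) - \log p(t; x, \theta)
\quad \text{rather than} \quad
\log p(t'; x', \theta_{\text{new}}) - \log p(t; x, \theta_{\text{old}})
```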

This is particularly interesting because it suggests parameters should behave differently from arguments. I had been thinking of parameters as a quality-of-life feature, easy to de-sugar into extra arguments + explicit maintenance of state (the user stores current parameter values and updates them when necessary). But arguments are meant to be used compositionally within a bigger model, and changes to arguments usually mean that random choices somewhere else have changed (justifying the weight-ratio behavior of update when given argdiffs). Parameters change for reasons having nothing to do with "random choices elsewhere."

FWIW, I like proposal (2) above!

marcoct commented 4 years ago

Yes, having thought more about various patterns and algorithms for learning models from incomplete data in Gen, I think we should go with (2).

One example use case that Gen needs to support is online Monte Carlo EM using MCMC, where we maintain a collection of traces and iterate between doing some inference in them using MCMC (starting from the previous collection of traces) and doing gradient-based parameter updates given complete data.
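A minimal plain-Python sketch of that loop, assuming a toy model with one latent x_i ~ N(theta, 1) and one observation y_i ~ N(x_i, 1) per data point. It does not use Gen: `log_joint`, `mh_step`, and `online_mcem` are hypothetical helpers standing in for traces, MH moves, and a gradient-based parameter update. Note that each MH acceptance ratio is computed with the current theta for both the old and the proposed value of x.

```python
import math
import random

# Toy model (assumed for illustration only, not from the issue):
#   x_i ~ Normal(theta, 1)   latent choice
#   y_i ~ Normal(x_i, 1)     observed choice
# theta stands in for a trainable parameter of the generative function.

def log_normal(v, mu, sigma=1.0):
    return -0.5 * ((v - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def log_joint(x, y, theta):
    return log_normal(x, theta) + log_normal(y, x)

def mh_step(x, y, theta, rng, step=1.0):
    """Random-walk Metropolis-Hastings on x; both scores use the current theta."""
    x_new = x + rng.gauss(0.0, step)
    log_alpha = log_joint(x_new, y, theta) - log_joint(x, y, theta)
    if rng.random() < math.exp(min(0.0, log_alpha)):
        return x_new
    return x

def online_mcem(ys, theta=0.0, iters=500, mh_steps=5, lr=0.05, seed=0):
    rng = random.Random(seed)
    xs = [0.0] * len(ys)  # one persistent "trace" (latent value) per observation
    for _ in range(iters):
        # Inference: continue MCMC from the previous collection of traces.
        for _ in range(mh_steps):
            xs = [mh_step(x, y, theta, rng) for x, y in zip(xs, ys)]
        # Learning: one gradient-ascent step on the complete-data log-likelihood,
        # d/dtheta sum_i log N(x_i; theta, 1) = sum_i (x_i - theta).
        grad = sum(x - theta for x in xs)
        theta += lr * grad / len(xs)
    return theta
```

For this model y_i is marginally N(theta, 2), so the maximum-likelihood theta is the sample mean of the observations, and the returned theta should settle close to it.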

marcoct commented 4 years ago

One type of online Monte Carlo EM involves taking a collection of importance-weighted traces and re-weighting them for the new parameters. The proposal of (2) doesn't automate this, but it can be accomplished by using get_score() on the traces before the update and then using get_score() after the parameter update, and adding the difference to the (log) importance weight, for each trace.

alex-lew commented 4 years ago


This makes me wonder if there should be a function for updating a trace according to a generative function’s current parameters, and returning the corresponding SMC weight.

Maybe it’s best if update does this after all (i.e., incorporates the parameter change into the weight). Then, to keep that out of the weight, you would just do your update in two steps: one to update to the current parameters (and you could choose to throw away the weight), and one to update the actual choices (or args).

From a performance perspective, this seems similar to option (2), which involves an implicit rescoring step anyway. If you want to keep people from accidentally calling mh on traces whose params haven’t been updated yet, you could return two weights (one for params and one for choices/args). Alternatively, you could add a function that checks whether a trace is up to date with respect to the params, and mh could call it to decide whether it needs to do two updates or one.
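A rough sketch of the two designs floated here (returning separate parameter and choice weights, and an up-to-date check that mh consults), in plain Python rather than Gen. Everything in it is hypothetical: `Model`, `Trace`, `make_trace`, `is_up_to_date`, `update_params`, `update_choices`, `update`, and `mh` are illustrative names, the toy model is x ~ N(theta, 1), y ~ N(x, 1), and parameter staleness is tracked with an assumed version counter on the model.

```python
import math
import random
from dataclasses import dataclass

def log_normal(v, mu, sigma=1.0):
    return -0.5 * ((v - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

@dataclass
class Model:
    """Hypothetical toy model x ~ N(theta, 1), y ~ N(x, 1) with trainable theta."""
    theta: float = 0.0
    version: int = 0  # assumed counter, bumped on every parameter update

    def set_theta(self, theta):
        self.theta = theta
        self.version += 1

    def log_joint(self, x, y):
        return log_normal(x, self.theta) + log_normal(y, x)

@dataclass(frozen=True)
class Trace:
    x: float
    y: float
    version: int  # parameter version the score was computed under
    score: float

def make_trace(model, x, y):
    return Trace(x, y, model.version, model.log_joint(x, y))

def is_up_to_date(trace, model):
    return trace.version == model.version

def update_params(trace, model):
    """Step 1: rescore under the current parameters; weight = new - old score."""
    new = make_trace(model, trace.x, trace.y)
    return new, new.score - trace.score

def update_choices(trace, model, x_new):
    """Step 2: change choices under a single (current) parameter setting."""
    assert is_up_to_date(trace, model)
    new = make_trace(model, x_new, trace.y)
    return new, new.score - trace.score

def update(trace, model, x_new):
    """Design A: return (trace, param_weight, choice_weight) separately."""
    param_weight = 0.0
    if not is_up_to_date(trace, model):
        trace, param_weight = update_params(trace, model)
    trace, choice_weight = update_choices(trace, model, x_new)
    return trace, param_weight, choice_weight

def mh(trace, model, rng, step=1.0):
    """Design B: mh checks staleness itself and uses only the choice weight."""
    if not is_up_to_date(trace, model):
        trace, _ = update_params(trace, model)  # discard the parameter weight
    proposed, w = update_choices(trace, model, trace.x + rng.gauss(0.0, step))
    return proposed if rng.random() < math.exp(min(0.0, w)) else trace
```

With `update`, the caller sees both weights and decides which to keep (an SMC step might keep the parameter weight, an MH step would drop it); with `mh`, the staleness check means the acceptance ratio is always computed under a single parameter setting.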