pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Delayed param in minipyro #533

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

This issue proposes creating delayed pyro.param sites in log_joint and then using funsor.adam.Adam to optimize the parameters. This would allow the optimization part of minipyro to be written in a way that is more consistent with funsor style (if there is such a thing) and would also make it backend-agnostic (with more backends supported in funsor.adam.Adam). Support for constrained variables (#502) might be useful here.
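For example, a site like pyro.param("loc", torch.zeros(2)) would be recorded as the free variable funsor.Variable("loc", Reals[2]) rather than as a concrete tensor, so the downstream loss expression stays symbolic in "loc" (the names here are illustrative).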

class log_joint(Messenger):
    ...
    def process_message(self, msg):
        if msg["type"] == "param":
            # Delay the param: record a free funsor variable
            # instead of a concrete tensor from the param store.
            msg["value"] = funsor.Variable(msg["name"], msg["output"])

class SVI:
    ...
    def run(self, *args, **kwargs):
        ...
        loss = ...  # a funsor whose free variables are the delayed params
        with funsor.montecarlo.MonteCarlo():
            with funsor.adam.Adam(**options):
                loss.reduce(ops.min)  # minimize over the delayed params
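Here the two interpretations divide the work: funsor.montecarlo.MonteCarlo replaces exact expectations (Integrate terms) inside the loss with sample-based estimates, while funsor.adam.Adam reinterprets loss.reduce(ops.min) as an Adam optimization loop over the loss's remaining free variables, which are exactly the delayed params introduced by log_joint.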

(I have tried this out for a simple guide where pyro.param doesn't have any constraints and is not nested in pyro.plate, and it seems to work.)
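For concreteness, a minimal sketch of the optimizer half in isolation. The toy loss and the name "z" are made up for illustration, and the usage (Adam as a context manager taking num_steps plus torch.optim.Adam keywords, with a param accessor to read back results) follows the Adam docstring, so treat the exact interface as an assumption:

import funsor
import funsor.ops as ops
from funsor.adam import Adam

funsor.set_backend("torch")

# Stand-in for a delayed, unconstrained pyro.param site.
z = funsor.Variable("z", funsor.Real)
loss = (z - 5.0) ** 2  # toy loss, minimized at z = 5

optim = Adam(num_steps=200, lr=0.1)
with optim:
    loss.reduce(ops.min)  # runs the Adam loop over the free variable "z"

print(optim.param("z"))  # optimized value, close to 5.0

SVI.run would issue the same reduce(ops.min) call, only with loss assembled by log_joint so that each pyro.param site contributes a free variable like "z".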

fritzo commented 3 years ago

This is a great idea! It's a near-literal translation of "(stochastic) (gradient descent)".