This issue proposes having log_joint create delayed pyro.param values and then using funsor.adam.Adam to optimize them. This would let the optimization part of minipyro be written in a way that is more consistent with funsor style (if there is such a thing) and would also make it backend-agnostic (once funsor.adam.Adam supports more backends). Support for constrained variables (#502) might be useful here.
class log_joint(Messenger):
    ...
    def process_message(self, msg):
        if msg["type"] == "param":
            msg["value"] = funsor.Variable(msg["name"], msg["output"])

class SVI:
    ...
    def run(self, *args, **kwargs):
        ...
        loss = ...
        with funsor.montecarlo.MonteCarlo():
            with funsor.adam.Adam(**options):
                loss.reduce(ops.min)
(I have tried this out for a simple guide where pyro.param has no constraints and is not nested in pyro.plate, and it seems to work.)
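To make the idea concrete without depending on funsor, here is a minimal pure-Python sketch of "leave the parameters as named free variables, then minimize the loss over them with Adam". It is only an illustration under stated assumptions: `numeric_grad`, `adam_minimize`, `toy_loss` and the parameter names `loc` and `log_scale` are hypothetical, gradients come from finite differences rather than a backend's autodiff, and none of this is funsor's actual API or implementation.

```python
import math


def numeric_grad(loss, values, h=1e-5):
    """Central finite-difference gradient of loss w.r.t. each named value."""
    grads = {}
    for name in values:
        up = dict(values, **{name: values[name] + h})
        down = dict(values, **{name: values[name] - h})
        grads[name] = (loss(up) - loss(down)) / (2 * h)
    return grads


def adam_minimize(loss, init, num_steps=500, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """Minimize loss over the named free variables in init using Adam updates."""
    values = dict(init)
    m = {name: 0.0 for name in values}  # first-moment estimates
    v = {name: 0.0 for name in values}  # second-moment estimates
    for t in range(1, num_steps + 1):
        grads = numeric_grad(loss, values)
        for name, g in grads.items():
            m[name] = b1 * m[name] + (1 - b1) * g
            v[name] = b2 * v[name] + (1 - b2) * g * g
            m_hat = m[name] / (1 - b1 ** t)  # bias correction
            v_hat = v[name] / (1 - b2 ** t)
            values[name] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return values


# Toy loss whose parameters stay "delayed" (unbound names) until optimization.
# The names "loc" and "log_scale" are hypothetical stand-ins for pyro.param sites.
def toy_loss(p):
    return (p["loc"] - 3.0) ** 2 + (p["log_scale"] + 1.0) ** 2


optimum = adam_minimize(toy_loss, {"loc": 0.0, "log_scale": 0.0})
```

In the proposal the same role would be played by funsor.Variable for the delayed parameters and by the reduction over ops.min under the Adam interpretation; an unconstrained parameterization such as `log_scale` is one way to sidestep the constraint question until #502 is resolved.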