pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Importance funsor #578

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

Importance sampling is represented by an Importance funsor.

  1. Signature: Importance(model, guide, sampled_vars).
  2. When guide is a Delta, it eagerly evaluates to guide + model - guide (see the sketch after this list).
  3. Importance.reduce delegates to Importance.model.reduce.
  4. (Not implemented) Consider a MonteCarlo interpretation for the case where guide is not a Delta.
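
For concreteness, items 2-3 could be registered roughly as follows (a hypothetical sketch, not the actual patch: Importance is the term proposed in this issue, and exact import paths vary across funsor versions):

from funsor.delta import Delta
from funsor.interpretations import eager
from funsor.terms import Funsor

@eager.register(Importance, Funsor, Delta, frozenset)
def eager_importance(model, guide, sampled_vars):
    # Item 2: with a Delta guide, collapse to guide + model - guide.
    # The terms are kept un-cancelled on purpose: a detached guide
    # leaves behind log_prob - ops.detach(log_prob), i.e. the dice
    # factor derived below.
    return guide + model - guide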

Dice factor as an importance weight

model = Delta(name, point, log_prob)
guide = Delta(name, point, ops.detach(log_prob))
Importance(model, guide, {name})
    == guide + model - guide
    == guide + log_prob - ops.detach(log_prob)
    == guide + dice_factor
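
After exponentiation the dice factor is exactly 1 in value, but it still carries the score-function gradient. A quick numeric check (assuming the PyTorch backend; hypothetical, not part of the proposal):

import torch

log_prob = torch.tensor(-1.3, requires_grad=True)
weight = (log_prob - log_prob.detach()).exp()  # exp(dice_factor) == 1.0 in value
weight.backward()
print(weight.item())         # 1.0
print(log_prob.grad.item())  # 1.0 -- the score-function gradient survives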

Lazy interpretation

from funsor.delta import Delta
from funsor.interpretations import DispatchedInterpretation, reflect
from funsor.terms import Funsor

lazy_importance = DispatchedInterpretation("lazy_importance")

@lazy_importance.register(Importance, Funsor, Delta, frozenset)
def _lazy_importance(model, guide, sampled_vars):
    # Build the Importance term reflectively instead of evaluating it,
    # so it survives as a node in the expression graph.
    return reflect.interpret(Importance, model, guide, sampled_vars)

It is used for lazy importance sampling:

with lazy_importance:
    sampled_dist = dist.sample(msg["name"], sample_inputs)

and for the adjoint algorithm, which operates on the lazily built expression graph:

with lazy_importance:
    marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq)
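
The adjoint pass needs the Importance terms to survive as nodes in the graph rather than collapsing eagerly, which is what reflect.interpret above guarantees. A toy illustration of the difference, independent of Importance:

from funsor.interpretations import reflect
from funsor.terms import Number

eager_sum = Number(1.0) + Number(2.0)
print(type(eager_sum).__name__)  # Number -- evaluated immediately

with reflect:
    lazy_sum = Number(1.0) + Number(2.0)
print(type(lazy_sum).__name__)   # Binary -- an unevaluated node that an
                                 # adjoint pass can traverse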

Separability

model_a = Delta("a", point_a["i"], log_prob_a)
guide_a = Delta("a", point_a["i"], ops.detach(log_prob_a))
q_a = Importance(model_a, guide_a, {"a"})

model_b = Delta("b", point_b["i"], log_prob_b)
guide_b = Delta("b", point_b["i"], ops.detach(log_prob_b))
q_b = Importance(model_b, guide_b, {"b"})

with lazy_importance:
    (q_a.exp() * q_b.exp() * cost_b).reduce(add, {"a", "b", "i"})
    == [q_a.exp().reduce(add, "a") * (q_b.exp() * cost_b).reduce(add, {"b"})].reduce(add, "i")
    == [1("i") * (q_b.exp().reduce(add, {"b"}) * cost_b(b=point_b))].reduce(add, "i")
    == [1("i") * 1("i") * cost_b(b=point_b)].reduce(add, "i")
    == cost_b(b=point_b).reduce(add, "i")
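
The same collapse can be checked in value with raw tensors, performing the Delta reductions by hand (a hypothetical stand-in for the funsor computation; names mirror the example above):

import torch

log_prob_a = torch.randn(3, requires_grad=True)  # batched over "i"
log_prob_b = torch.randn(3, requires_grad=True)
cost_b = torch.randn(3)  # cost_b already evaluated at b = point_b["i"]

# Reducing q.exp() over its own Delta variable leaves the dice weight
# exp(log_prob - detach(log_prob)), which is exactly 1 in value.
w_a = (log_prob_a - log_prob_a.detach()).exp()
w_b = (log_prob_b - log_prob_b.detach()).exp()

lhs = (w_a * w_b * cost_b).sum()  # reduce over {"a", "b", "i"}
rhs = cost_b.sum()                # cost_b(b=point_b).reduce(add, "i")
assert torch.allclose(lhs, rhs)   # equal in value; lhs also carries gradients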