pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Autodiff implementation (experimental) #494

[Open] ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

This is an experimental implementation of autodiff. The goal is to fix problems with computing expectations in TraceEnum_ELBO and TraceMarkovEnum_ELBO (#493). So far, it appears to fix NaN gradients in TraceEnum_ELBO under the eager interpretation.

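The algorithm implements equivalents of the linearize() and transpose() functions and is tape-free (#446). Composing the two gives reverse-mode derivatives: linearize the computation, then transpose the resulting linear function.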

  1. Linearize. Each variable to be linearized is replaced by a primal-tangent pair JVP(primal, tangent), and pattern matching then propagates tangents through the computation, e.g.:
JVP(x, dx) + JVP(y, dy) = JVP(x + y, dx + dy)
JVP(x, dx) * JVP(y, dy) = JVP(x * y, y * dx + x * dy)
JVP(x, dx) * y = JVP(x * y, y * dx)

The output tangent is a linear function of the input tangents. JVP is used for the (add, mul) semiring, and LJVP for the (logaddexp, add) semiring.

  2. Transpose the linear function. The transpose is implemented by reversing the order in which the operations execute and transposing each one, just as a product of matrices transposes as (AB)^T = B^T A^T. Here that amounts to swapping the more primitive operations .reduce(sum_op, "i") and .expand("i"); broadcasting performs the .expand("i") automatically.
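The linearization rules in step 1 can be sketched for a generic semiring. This sketch assumes (the PR does not spell it out) that the LJVP rules for the (logaddexp, add) semiring come from the JVP rules by substituting logaddexp for + and add for *, which amounts to keeping tangents in log space. `Semiring`, `Dual`, `const`, `dual_plus` and `dual_times` are hypothetical names for this illustration. They are not funsor's `JVP`/`LJVP` classes, and they operate on plain Python floats rather than funsors.

```python
import math
from dataclasses import dataclass
from typing import Callable


def logaddexp(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b)); -inf is the identity element.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


@dataclass(frozen=True)
class Semiring:
    plus: Callable[[float, float], float]   # semiring "sum" (⊕)
    times: Callable[[float, float], float]  # semiring "product" (⊗)
    zero: float                             # identity of ⊕, used as the tangent of a constant


SUM_PRODUCT = Semiring(plus=lambda a, b: a + b, times=lambda a, b: a * b, zero=0.0)  # (add, mul)
LOG = Semiring(plus=logaddexp, times=lambda a, b: a + b, zero=-math.inf)            # (logaddexp, add)


@dataclass(frozen=True)
class Dual:
    """Hypothetical primal-tangent pair standing in for JVP / LJVP."""
    primal: float
    tangent: float
    sr: Semiring


def const(value: float, sr: Semiring) -> Dual:
    # A value that is not being linearized carries the zero tangent.
    return Dual(value, sr.zero, sr)


def dual_plus(a: Dual, b: Dual) -> Dual:
    # (x, dx) ⊕ (y, dy) = (x ⊕ y, dx ⊕ dy)
    s = a.sr
    return Dual(s.plus(a.primal, b.primal), s.plus(a.tangent, b.tangent), s)


def dual_times(a: Dual, b: Dual) -> Dual:
    # (x, dx) ⊗ (y, dy) = (x ⊗ y, (y ⊗ dx) ⊕ (x ⊗ dy)); with dy = zero this
    # reduces to the JVP(x, dx) * y = JVP(x * y, y * dx) rule.
    s = a.sr
    tangent = s.plus(s.times(b.primal, a.tangent), s.times(a.primal, b.tangent))
    return Dual(s.times(a.primal, b.primal), tangent, s)
```

For example, `dual_plus(dual_times(x, x), x)` with `x = Dual(3.0, 1.0, SUM_PRODUCT)` gives primal 12.0 and tangent 7.0, the derivative of x² + x at 3. In the log semiring, primals and tangents are logs of nonnegative values. With X = 2, dX = 0.5, Y = 5, dY = 1.5, exponentiating the tangent of `dual_times` recovers Y·dX + X·dY = 5.5 up to rounding. The restriction to nonnegative linear-space tangents is a consequence of this log-space assumption.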
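The transpose step can be checked numerically on plain NumPy arrays. In this hypothetical sketch, `reduce_sum_i` and `expand_i` stand in for `.reduce(sum_op, "i")` and `.expand("i")`, with axis 0 playing the role of "i" and sum_op taken to be ordinary addition. `make_linear` builds the linear map L(u) = sum over i of w * u together with its transpose by running the primitives in reverse order and swapping reduce-sum for expand. The dot-product identity <L(u), v> = <u, L^T(v)> confirms the transpose is correct.

```python
import numpy as np

# Hypothetical stand-ins: axis 0 plays the role of the named dim "i",
# and sum_op is taken to be ordinary addition.


def reduce_sum_i(u: np.ndarray) -> np.ndarray:
    # Analogue of .reduce(sum_op, "i"): sum out axis 0.
    return u.sum(axis=0)


def expand_i(v: np.ndarray, size_i: int) -> np.ndarray:
    # Analogue of .expand("i"): repeat v along a new leading axis of length size_i.
    return np.broadcast_to(v, (size_i,) + v.shape)


def make_linear(w: np.ndarray):
    """Return L(u) = reduce_sum_i(w * u) and its transpose; w is a fixed primal."""

    def L(u: np.ndarray) -> np.ndarray:
        # Forward order: scale by w, then reduce over "i".
        return reduce_sum_i(w * u)

    def L_T(ct: np.ndarray) -> np.ndarray:
        # Reverse order, each step transposed: reduce becomes expand, then
        # scale by w (elementwise scaling is its own transpose).
        return w * expand_i(ct, w.shape[0])

    return L, L_T
```

Dropping `expand_i` and writing `w * ct` gives the same result, since NumPy broadcasting inserts the expansion implicitly.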