liesel-devs / liesel

A probabilistic programming framework
https://liesel-project.org
MIT License

Refactor GraphBuilder.transform() method #93

Open jobrachem opened 1 year ago

jobrachem commented 1 year ago

Notes from a discussion with @hriebl:

Currently, the GraphBuilder.transform() method covers two cases at the same time:

  1. Users explicitly supply a bijector that should be used for the transformation.
  2. Users do not explicitly supply a bijector; instead, the default bijector is obtained from the node's TensorFlow distribution class.

The idea of this issue is to internally separate the functionality of the transform method into these two parts. The method's API would stay the same.
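A minimal sketch of the proposed internal split, using toy stand-ins rather than Liesel's actual classes. The names `ToyBuilder`, `ToyDist`, `Bijector`, `_transform_explicit`, and `_transform_default` are all hypothetical, and a bijector is modelled as a plain pair of forward/inverse functions. The public `transform()` signature stays the same and only dispatches to one of two private helpers:

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class Bijector:
    # Hypothetical stand-in for a bijector: a forward map and its inverse.
    forward: Callable[[float], float]
    inverse: Callable[[float], float]


@dataclass
class ToyDist:
    # Hypothetical distribution on the positive reals; its default
    # bijector is the exp/log pair.
    name: str

    def default_bijector(self) -> Bijector:
        return Bijector(forward=math.exp, inverse=math.log)


class ToyBuilder:
    """Toy builder that only illustrates the dispatch inside transform()."""

    def transform(
        self, dist: ToyDist, bijector: Optional[Bijector] = None
    ) -> Tuple[str, Bijector]:
        # The public signature is unchanged; whether a bijector was
        # supplied decides which private code path runs.
        if bijector is None:
            return self._transform_default(dist)
        return self._transform_explicit(dist, bijector)

    def _transform_explicit(
        self, dist: ToyDist, bijector: Bijector
    ) -> Tuple[str, Bijector]:
        # Case 1: the bijector is given directly, so the distribution's
        # inputs are not needed to construct it.
        return "explicit", bijector

    def _transform_default(self, dist: ToyDist) -> Tuple[str, Bijector]:
        # Case 2: the bijector has to be obtained from the distribution.
        return "default", dist.default_bijector()


builder = ToyBuilder()
path, bij = builder.transform(ToyDist("sigma"))
print(path, bij.forward(0.0))  # default 1.0
```

Keeping the branching in one thin public method means callers see no difference, while each private helper is free to wire the graph differently.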

This separation has a useful side effect: for bijectors explicitly supplied by users, we can determine whether the model graph still requires edges from the inputs of the original node's distribution to the original node, or whether it suffices for these edges to point to the transformed node's distribution. Consider a model graph in which the distribution of sigma takes a and b as inputs, and sigma is transformed into sigma_transformed. In this graph, we could omit the edges from a and b to sigma, because after the transformation, a and b are only used in sigma_transformed.
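To make the edge argument concrete, here is a toy edge list for the sigma example. The node names follow the issue; the function `transform_edges` and its `explicit_bijector` flag are hypothetical and only model which edges would be kept under the proposal, not how Liesel stores its graph:

```python
def transform_edges(dist_inputs, node, explicit_bijector):
    """Return the set of directed edges (source, target) after transforming `node`.

    Hypothetical model of the proposal: the distribution (with its inputs)
    moves to the transformed node, and the original node becomes a
    back-transformation of the transformed node.
    """
    transformed = f"{node}_transformed"
    # The inputs of the original distribution now feed the transformed node.
    edges = {(inp, transformed) for inp in dist_inputs}
    # The original node is always computed from the transformed node.
    edges.add((transformed, node))
    if not explicit_bijector:
        # Default bijector: it is built from the distribution, so the
        # back-transformation still needs the distribution's inputs.
        edges |= {(inp, node) for inp in dist_inputs}
    return edges


explicit = transform_edges(["a", "b"], "sigma", explicit_bijector=True)
default = transform_edges(["a", "b"], "sigma", explicit_bijector=False)
print(sorted(default - explicit))  # [('a', 'sigma'), ('b', 'sigma')]
```

The set difference is exactly the pair of edges from a and b to sigma that the explicit-bijector path could drop.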

When the default bijector is obtained automatically from the node's TensorFlow distribution, there seems to be no way for us to know whether the bijector needs these edges. Because they may be needed and, more technically, because we always need these inputs to initialize the distribution class in order to obtain the bijector in the first place, we cannot remove the edges in the second case.
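A small illustration of why the inputs cannot be dropped in the default case, assuming a hypothetical bounded distribution `ToyUniform(low, high)` whose default bijector is a scaled logistic sigmoid. This is a toy, not a TensorFlow class: the point is only that the bijector can be obtained solely from an instantiated distribution, and its forward map closes over the parameters that come from a and b:

```python
import math


class ToyUniform:
    """Hypothetical uniform distribution on (low, high), playing the role
    of sigma's distribution with parameters supplied by inputs a and b."""

    def __init__(self, low: float, high: float):
        self.low = low
        self.high = high

    def default_bijector(self):
        # Scaled logistic sigmoid mapping the real line onto (low, high).
        # The forward map closes over low and high, i.e. over the inputs.
        def forward(x: float) -> float:
            return self.low + (self.high - self.low) / (1.0 + math.exp(-x))

        return forward


def back_transform(value_transformed: float, a: float, b: float) -> float:
    # The distribution must be instantiated from its inputs before the
    # default bijector can even be obtained, so a and b are required here.
    bijector_forward = ToyUniform(low=a, high=b).default_bijector()
    return bijector_forward(value_transformed)


print(back_transform(0.0, a=0.0, b=1.0))  # 0.5
print(back_transform(0.0, a=2.0, b=4.0))  # 3.0
```

The same transformed value maps to different original values depending on a and b, so the back-transformation node has to depend on them.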

Technical background:

In the internal helper function _transform_back, the input parameters of the original node's distribution are passed as inputs to the back-transformation calculator:

return Calc(fn, var_transformed.value_node, *inputs, **kwinputs)  # type: ignore

https://github.com/liesel-devs/liesel/blob/e3eec6e885173a4d4c7430d58e74abbd8ec4543b/liesel/model/model.py#L81C1-L81C1
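The line above passes every input of the original distribution to the calculator, and every input of a calculator becomes an incoming edge. The toy class `ToyCalc` below is hypothetical and only mimics the call pattern `Calc(fn, value_node, *inputs, **kwinputs)` with node names as strings; it does not reproduce Liesel's `Calc`. It shows how the edges arise and how they would disappear if the back-transformation function did not need the distribution's parameters:

```python
import math
from typing import Callable, Dict, List


class ToyCalc:
    """Hypothetical stand-in for a calculator node. Inputs are node names;
    each positional or keyword input is recorded as an incoming edge."""

    def __init__(self, fn: Callable[..., float], *inputs: str, **kwinputs: str):
        self.fn = fn
        self.inputs = inputs
        self.kwinputs = kwinputs

    def incoming(self) -> List[str]:
        # Every input becomes an edge pointing into this node.
        return list(self.inputs) + list(self.kwinputs.values())

    def value(self, values: Dict[str, float]) -> float:
        # Look up the current value of every input node, then call fn.
        args = [values[name] for name in self.inputs]
        kwargs = {key: values[name] for key, name in self.kwinputs.items()}
        return self.fn(*args, **kwargs)


def scaled_sigmoid(x: float, low: float, high: float) -> float:
    # Back-transformation that depends on the distribution's parameters.
    return low + (high - low) / (1.0 + math.exp(-x))


# Mirrors the call pattern Calc(fn, var_transformed.value_node, *inputs, **kwinputs):
with_inputs = ToyCalc(scaled_sigmoid, "sigma_transformed", low="a", high="b")
# A user-supplied bijector that needs no parameters, e.g. exp:
without_inputs = ToyCalc(math.exp, "sigma_transformed")

values = {"sigma_transformed": 0.0, "a": 2.0, "b": 4.0}
print(with_inputs.incoming(), with_inputs.value(values))
# ['sigma_transformed', 'a', 'b'] 3.0
print(without_inputs.incoming(), without_inputs.value(values))
# ['sigma_transformed'] 1.0
```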

hriebl commented 1 year ago

Thanks for the detailed write-up of our discussion! I'm not sure if improving the graph visualization is worth the effort, but if we decide to do this, I think splitting the GraphBuilder.transform() method is the way to go.

jobrachem commented 1 year ago

Maybe it will someday be tackled on a lazy Sunday afternoon... 😊