liesel-devs / liesel

A probabilistic programming framework
https://liesel-project.org
MIT License

Call Calc.update() once as the last step of initialization #79

Closed: jobrachem closed this issue 8 months ago

jobrachem commented 1 year ago

The following example works fine:

>>> import liesel.model as lsl
>>> import jax.numpy as jnp
>>> 
>>> my_node = lsl.Data(1.0)
>>> my_calc = lsl.Calc(jnp.exp, my_node)
>>> my_calc.update()
Calc(name="")
>>> my_calc.value
Array(2.7182817, dtype=float32, weak_type=True)

However, the update fails when both of the following hold:

  1. There is a second level, i.e. another calculator node that takes my_calc as its input, and
  2. the value of my_calc has not been updated yet.

Demonstration:

>>> my_node = lsl.Data(1.0)
>>> my_calc = lsl.Calc(jnp.exp, my_node)
>>> my_calc2 = lsl.Calc(jnp.log, my_calc)
>>> my_calc2.update()
Traceback (most recent call last):
  [...]
  File ".../lib/python3.10/site-packages/jax/_src/numpy/util.py", line 344, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: log requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.
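Judging from the traceback, my_calc.value is still None when my_calc2.update() runs, so update() apparently reads the current values of its inputs without updating them first. The sketch below reproduces this failure mode with hypothetical toy classes (Data, LazyCalc). It is not liesel's actual implementation, and it uses Python's math module in place of jax.numpy so that it runs without JAX:

```python
import math


class Data:
    """Hypothetical stand-in for a data node: it just holds a fixed value."""

    def __init__(self, value):
        self.value = value


class LazyCalc:
    """Hypothetical stand-in for a calculator that is only evaluated on update()."""

    def __init__(self, function, *inputs):
        self.function = function
        self.inputs = inputs
        self.value = None  # nothing is computed at construction time

    def update(self):
        # Reads the inputs' current values; the inputs are NOT updated first.
        self.value = self.function(*(node.value for node in self.inputs))
        return self


my_node = Data(1.0)
my_calc = LazyCalc(math.exp, my_node)
my_calc2 = LazyCalc(math.log, my_calc)

error = None
try:
    my_calc2.update()  # fails: my_calc.value is still None at this point
except TypeError as err:
    error = err
```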

This sometimes makes debugging more tedious than I would like. The current workaround is to call .update() on each calculator node directly after creating it, so that every node's inputs already have values by the time the node itself is updated, as done for my_calc in the first example above.
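The workaround can be sketched with the same hypothetical toy classes (Data, LazyCalc; not liesel's API, with math standing in for jax.numpy): each calculator is updated right after it is created, so nodes are evaluated in dependency order.

```python
import math


class Data:
    """Hypothetical stand-in for a data node."""

    def __init__(self, value):
        self.value = value


class LazyCalc:
    """Hypothetical calculator that is only evaluated on update()."""

    def __init__(self, function, *inputs):
        self.function = function
        self.inputs = inputs
        self.value = None

    def update(self):
        self.value = self.function(*(node.value for node in self.inputs))
        return self  # returning self allows chaining the call after construction


my_node = Data(1.0)
my_calc = LazyCalc(math.exp, my_node).update()   # updated immediately after creation
my_calc2 = LazyCalc(math.log, my_calc).update()  # its input already has a value
```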

From my point of view, it would be convenient if calculators always updated themselves once upon initialization, which would solve this issue. As a result, coding problems in calculators would surface immediately, which would be really nice in model building, because it catches hard-to-debug problems early on.
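A minimal sketch of the proposal, again using hypothetical toy classes (Data, EagerCalc) rather than liesel's actual code, with math in place of jax.numpy: update() is called once as the last step of __init__, so the two-level example works without manual updates, and a faulty calculator function fails at construction time.

```python
import math


class Data:
    """Hypothetical stand-in for a data node."""

    def __init__(self, value):
        self.value = value


class EagerCalc:
    """Hypothetical calculator that evaluates itself once at the end of __init__."""

    def __init__(self, function, *inputs):
        self.function = function
        self.inputs = inputs
        self.value = None
        self.update()  # last step of initialization

    def update(self):
        self.value = self.function(*(node.value for node in self.inputs))
        return self


my_node = Data(1.0)
my_calc = EagerCalc(math.exp, my_node)
my_calc2 = EagerCalc(math.log, my_calc)  # works: my_calc already has a value

# A bug in a calculator function now raises as soon as the node is created.
construction_error = None
try:
    EagerCalc(lambda x: x + "oops", my_node)
except TypeError as err:
    construction_error = err
```

In this sketch, the ordering works out automatically because a node's inputs must already exist, and therefore already be evaluated, before the node itself can be constructed.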

Is there something that speaks against a change like this?

jobrachem commented 1 year ago