pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Do not unflatten unevaluated lazy properties. #1778

Closed tillahoffmann closed 3 months ago

tillahoffmann commented 3 months ago

The tree_flatten function uses self.__dict__.get(name) to obtain the data field with the given name. Unevaluated lazy properties do not appear in self.__dict__, so they are flattened as None. Unflattening then sets them to None on the reconstructed instance instead of leaving them to be evaluated on first access. Here is an example.

>>> import jax
>>> from jax import numpy as jnp
>>> import numpyro
>>>
>>> @jax.jit
... def f1(x):
...     return numpyro.distributions.MultivariateNormal(jnp.zeros(3), jnp.eye(3))
...
>>> @jax.jit
... def f2(x):
...     dist = numpyro.distributions.MultivariateNormal(jnp.zeros(3), jnp.eye(3))
...     dist.precision_matrix  # evaluate the lazy property so it is cached in __dict__
...     return dist
...

>>> print(f1(0).precision_matrix)
None
>>> print(f2(0).precision_matrix)
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

With the changes in this PR, an attribute is set on the reconstructed instance only if its value is not None or if the attribute is not a lazy property. A None value is ambiguous, since it can mean either an unevaluated lazy property or a lazy property that evaluated to None, but the implementation remains correct in both cases: if the evaluated value is None, it is not restored on the reconstructed instance, and the property simply re-evaluates to None the first time it is accessed there.
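For illustration, here is a minimal, self-contained sketch of that rule. It is not the NumPyro implementation: ToyDist, its data_fields tuple, and the simplified tree_flatten/tree_unflatten signatures are hypothetical, and functools.cached_property stands in for NumPyro's lazy_property on the assumption that both cache the computed value in the instance __dict__.

```python
from functools import cached_property


class ToyDist:
    # Hypothetical field names that make up the flattened representation.
    data_fields = ("loc", "scale", "precision")

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    @cached_property
    def precision(self):
        # Stand-in for a derived quantity; cached in __dict__ on first access.
        return 1.0 / self.scale**2

    def tree_flatten(self):
        # Unevaluated lazy properties are absent from __dict__, so .get() yields None.
        return tuple(self.__dict__.get(name) for name in self.data_fields)

    @classmethod
    def tree_unflatten(cls, values):
        obj = cls.__new__(cls)  # bypass __init__, as a reconstruction would
        for name, value in zip(cls.data_fields, values):
            is_lazy = isinstance(getattr(cls, name, None), cached_property)
            # Skip only "None value on a lazy property": leaving the attribute
            # unset lets the property be evaluated on first access.
            if value is not None or not is_lazy:
                setattr(obj, name, value)
        return obj


# Round trip without evaluating the lazy property first.
restored = ToyDist.tree_unflatten(ToyDist(0.0, 2.0).tree_flatten())
print(restored.precision)  # computed lazily on the restored instance
```

If the property had been evaluated before flattening, its cached value would be carried through and set directly on the restored instance instead of being recomputed.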

I've also added a test.

fehiepsi commented 3 months ago

Hi @tillahoffmann, could you point me to an example where the value is None but it is not a lazy property?

tillahoffmann commented 3 months ago

I'm not actually aware of an example where the value is None. We could also just not set the value if it's None without the extra lazy_property check.

I added it to prevent surprises. E.g., if a user implemented a custom distribution where one of the fields has a None value and we didn't check the lazy_property, they would get an AttributeError if they accessed the attribute after a flatten/unflatten cycle. It's probably an unlikely scenario, however.
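A rough sketch of that scenario, under stated assumptions: ToyMasked, its field names, and unflatten_skip_none are hypothetical and only mimic the shape of the flatten/unflatten cycle. The variant below drops every None value without checking for a lazy property, so a plain field that is legitimately None goes missing after the round trip.

```python
class ToyMasked:
    # Hypothetical field names; "mask" is a plain attribute, not a lazy property.
    data_fields = ("loc", "mask")

    def __init__(self, loc, mask=None):
        self.loc = loc
        self.mask = mask  # legitimately None by default

    def tree_flatten(self):
        return tuple(self.__dict__.get(name) for name in self.data_fields)

    @classmethod
    def unflatten_skip_none(cls, values):
        # Simplified variant: skip every None value, lazy property or not.
        obj = cls.__new__(cls)
        for name, value in zip(cls.data_fields, values):
            if value is not None:
                setattr(obj, name, value)
        return obj


restored = ToyMasked.unflatten_skip_none(ToyMasked(0.0).tree_flatten())
try:
    restored.mask
except AttributeError:
    # The plain field was never set, which is the surprise the extra check avoids.
    print("mask is missing on the restored instance")
```

With the additional lazy-property check, the None value would be set for `mask` because it is an ordinary attribute, and the access would return None as before the round trip.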

fehiepsi commented 3 months ago

Oh, I think I understand your implementation now. That makes sense to me. Thanks!!