probabilists / zuko

Normalizing flows in PyTorch
https://zuko.readthedocs.io
MIT License
333 stars 24 forks

It seems the `cdf()` in `DiagNormal` has not been implemented? #40

Closed Bill-Gots closed 9 months ago

Bill-Gots commented 9 months ago

Running the following lines

import zuko
model = zuko.flows.GF(3)
print(model.base().cdf(0.))

I got a NotImplementedError in torch\distributions\distribution.py.

Also, I wanted to truncate the base distribution in flows.GF, but I could not manage it.

import zuko
model = zuko.flows.GF(3)
model.base = zuko.distributions.Truncated(model.base(), -1., 1.)
samples = model().rsample((32,))
print(samples)

Running the code above I got

File "D:\Program Files\Python3.11.4\Lib\site-packages\zuko\distributions.py", line 458, in __init__
    assert not base.event_shape, "'base' has to be univariate"
AssertionError: 'base' has to be univariate

I am new to normalizing flows, so I don't know how to fix this. Any help is appreciated.

francois-rozet commented 9 months ago

Hello, DiagNormal is a multivariate normal distribution with a diagonal covariance matrix. A multivariate distribution does not have a CDF, or rather there is no unique definition of the CDF of a multivariate distribution. What do you want to achieve exactly?
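For illustration, the same error can be reproduced in plain torch terms (a sketch: DiagNormal behaves like an Independent-wrapped Normal, and torch's generic Distribution.cdf is not implemented for such objects):

```python
import torch
from torch.distributions import Independent, Normal

# An Independent-wrapped Normal behaves like a diagonal multivariate
# normal; torch's base Distribution.cdf is not implemented for it.
d = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)

try:
    d.cdf(torch.zeros(3))
except NotImplementedError:
    print("no CDF for the multivariate distribution")

# The per-dimension (univariate) CDFs are well defined, however:
print(Normal(torch.zeros(3), torch.ones(3)).cdf(torch.zeros(3)))  # tensor([0.5000, 0.5000, 0.5000])
```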

I guess that you want to define a multivariate truncated normal. zuko.distributions.Truncated works for univariate distributions, as the error message suggests. In your case, you should truncate univariate normal distributions, and then stack them as a single multivariate distribution. Note that the base attribute should be a lazy distribution.

import torch
import zuko
from torch import Tensor
from torch.distributions import Independent, Normal
from zuko.distributions import Truncated

class LazyDiagTruncatedNormal(zuko.flows.LazyDistribution):
    def __init__(self, features: int, lower: float = -1.0, upper: float = 1.0):
        super().__init__()

        # Buffers move with .to(device) but are not trained.
        self.register_buffer('loc', torch.zeros(features))
        self.register_buffer('scale', torch.ones(features))
        self.register_buffer('lower', torch.as_tensor(lower))
        self.register_buffer('upper', torch.as_tensor(upper))

    def forward(self, c: Tensor = None):
        # Truncate each univariate Normal, then reinterpret the batch
        # dimension as an event dimension to get a multivariate distribution.
        return Independent(Truncated(Normal(self.loc, self.scale), self.lower, self.upper), 1)

model.base = LazyDiagTruncatedNormal(3)

I don't know why you want a truncated base distribution, but I should warn you that evaluating the log_prob of such a flow will likely lead to NaNs / $-\infty$, as the probability outside of the interval is null ($\log 0$).
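To illustrate the $\log 0$ effect with a plain torch distribution of bounded support (torch's Uniform stands in here for any truncated distribution; validate_args=False so out-of-support points return $-\infty$ instead of raising):

```python
import torch
from torch.distributions import Uniform

# A distribution supported on [-1, 1] assigns zero probability outside
# that interval, so log_prob there is log(0) = -inf.
d = Uniform(torch.tensor(-1.0), torch.tensor(1.0), validate_args=False)

print(d.log_prob(torch.tensor(0.0)))  # log(1/2), inside the support
print(d.log_prob(torch.tensor(2.0)))  # -inf, outside the support
```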

Bill-Gots commented 9 months ago

Thanks for your quick reply! After reading the documentation, I'm still not very sure about the usage of many classes because I'm just learning normalizing flows. What I want to achieve is to initialize an NSF (or some other flow in zuko.flows) and then make it "truncated", i.e. add a "truncate transform" to that NSF. I first tried to add SoftclipTransform to model.transforms, but failed because the transform is not an nn.Module. Then I tried Truncated as above and, as you can see, failed again.

I'm learning to use zuko but lack some fundamental knowledge, so I'm struggling to understand the usage. For example, why are there "LazyTransforms" (are there some "hard-working transforms"? just a joke)? So please forgive me for any misunderstanding. :)

francois-rozet commented 9 months ago

For example, why are there "LazyTransforms" (are there some "hard-working transforms"? just a joke)?

Have you checked the Learn the basics tutorial and in particular the parametrization section? It is explained there why LazyDistribution and LazyTransform are necessary, and what they do.

When designing the distributions module, the PyTorch team decided that distributions and transformations should be lightweight objects that are used as part of computations but destroyed afterwards. Consequently, the Distribution and Transform classes are not subclasses of torch.nn.Module, which means that we cannot retrieve their parameters with .parameters(), send their internal tensors to the GPU with .to('cuda'), or train them as regular neural networks. In addition, the concepts of conditional distributions and transformations, which are essential for probabilistic inference, are impossible to express with this interface.
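A quick check of this, using torch's Normal as an example:

```python
import torch
from torch.distributions import Normal

d = Normal(torch.zeros(3), torch.ones(3))

# A Distribution is not a Module: no .parameters(), no .to('cuda'),
# and it cannot be registered as a trainable sub-module.
print(isinstance(d, torch.nn.Module))  # False
print(hasattr(d, 'parameters'))        # False
```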

To solve these problems, zuko defines two concepts: the LazyDistribution and the LazyTransform, which are modules whose forward pass returns a distribution or transformation, respectively. These components hold the parameters of the distributions/transformations as well as the recipe to build them, such that the actual distribution/transformation objects are lazily built and destroyed when necessary. Importantly, because the creation of the distribution/transformation object is delayed, a conditioning variable, if any, can easily be taken into account. This design enables lazy distributions to act like distributions while retaining features inherent to modules, such as trainable parameters.
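A minimal sketch of the idea (a hypothetical LazyNormal, not one of zuko's classes): a Module holds the trainable parameters, and its forward() builds a fresh Distribution on demand.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class LazyNormal(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(features))
        self.log_scale = nn.Parameter(torch.zeros(features))

    def forward(self):
        # The Distribution object is built on demand and can be discarded
        # after use; the trainable parameters live in the Module.
        return Normal(self.loc, self.log_scale.exp())

lazy = LazyNormal(3)
print(len(list(lazy.parameters())))  # 2: trainable parameters are visible
dist = lazy()                        # build the actual Distribution
print(dist.sample().shape)           # torch.Size([3])
```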

What I want to achieve exactly is a snippet of initializing an NSF [...] and then making it "truncated"

What do you mean by a "truncated flow"? Should the support of the distribution defined by the flow be limited to some interval (hypercube)?