probabilists / zuko

Normalizing flows in PyTorch
https://zuko.readthedocs.io
MIT License

Add Extrapolation to Bernstein Polynomial Flow #36

Closed MArpogaus closed 2 months ago

MArpogaus commented 5 months ago

Description

As discussed in #32, we have now implemented linear extrapolation outside the bounds of the Bernstein polynomial. The feature becomes active if linear=True.

Here is a simple Python script to visualize the resulting effect:

# %% Imports
import torch

from matplotlib import pyplot as plt
from zuko.transforms import BernsteinTransform

# %% Globals
M = 10
batch_size = 10
torch.manual_seed(1)
theta = torch.rand(size=(M,)) * 500  # creates a random parameter vector

# %% Sigmoid
bpoly = BernsteinTransform(theta=theta, linear=False)

x = torch.linspace(-15, 15, 2000)
y = bpoly(x)

adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot
fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with Sigmoid")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("sigmoid.png")

[Figure: Bernstein polynomial with sigmoid (sigmoid.png)]

# %% Extrapolation
bpoly = BernsteinTransform(theta=theta, linear=True)

x = torch.linspace(-15, 15, 2000)
y = bpoly(x)

adj = bpoly.log_abs_det_jacobian(x, y).detach()
J = torch.diag(torch.autograd.functional.jacobian(bpoly, x)).abs().log()

# %% Plot

fig, axs = plt.subplots(2, sharex=True)
fig.suptitle("Bernstein polynomial with linear extrapolation")
axs[0].plot(x, y, label="Bernstein polynomial")
axs[0].scatter(
    torch.linspace(-10, 10, bpoly.order + 1),
    bpoly.theta.numpy().flatten(),
    label="Bernstein coefficients",
)
axs[0].legend()
axs[1].plot(x, adj, label="ladj")
# axs[1].scatter(
#     torch.linspace(-10, 10, bpoly.order),
#     bpoly.dtheta.numpy().flatten(),
#     label="dtheta",
# )
axs[1].plot(x, J, label="ladj (autograd)")
axs[1].legend()
fig.tight_layout()
fig.savefig("linear.png")

[Figure: Bernstein polynomial with linear extrapolation (linear.png)]

This makes the BPF implementation more robust to data lying outside the domain of the Bernstein polynomial, without the need for the non-linear sigmoid function.

@oduerr Do you have anything else to add?

Implementation

The implementation can be found in the bpf_extrapolation branch of my fork.

My changes specifically include:

  1. Optional linear extrapolation in the call method
  2. A custom implementation of log_abs_det_jacobian, since the gradient does not seem to pass through the torch.where statement, and to improve numerical stability in the sigmoidal case.
francois-rozet commented 5 months ago

Hello @MArpogaus, thanks for the feature request and PR! I am curious about the issues you had with gradients. Normally torch.where should work fine with autograd as it computes more or less cond * left + (1 - cond) * right, except if left or right is NaN. So my guess is that evaluating the polynomial outside of $[0, 1]$ returns a NaN, which breaks autograd.

A solution could be to cast $x$ to 0.5 when $x \not \in [0, 1]$ before evaluating the polynomial.
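A minimal sketch of the suspected NaN-gradient issue and of the cast-to-0.5 fix. It uses a hypothetical stand-in, f(x) = sqrt(x(1 - x)), which (like the suspected polynomial evaluation) returns NaN outside [0, 1]; it does not reproduce the actual BernsteinTransform code.

```python
import torch


def f(x):
    # Hypothetical stand-in that is only defined on [0, 1]; NaN outside.
    return torch.sqrt(x * (1 - x))


def naive(x):
    inside = (x >= 0) & (x <= 1)
    # f is evaluated everywhere; the unselected NaN branch receives a zero upstream
    # gradient, but sqrt's backward computes 0 / NaN = NaN, which reaches x.grad.
    return torch.where(inside, f(x), x)


def safe(x):
    inside = (x >= 0) & (x <= 1)
    # Cast out-of-domain inputs to 0.5 before evaluating f, so both branches
    # (and their gradients) stay finite.
    x_safe = torch.where(inside, x, torch.full_like(x, 0.5))
    return torch.where(inside, f(x_safe), x)


def grad_of(fn, values):
    x = torch.tensor(values, requires_grad=True)
    fn(x).sum().backward()
    return x.grad


naive_grad = grad_of(naive, [-0.5, 0.25, 1.5])  # NaN at -0.5 and 1.5
safe_grad = grad_of(safe, [-0.5, 0.25, 1.5])    # 1.0 outside, f'(0.25) inside
```

With the cast, the out-of-domain entries get the gradient of the identity branch (1.0), and the in-domain entry is unchanged.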

Also how did you handle the inverse transformation with the extrapolation? I guess you would also need to look whether $y$ is larger/smaller than the values at the boundaries.

MArpogaus commented 5 months ago

I am curious about the issues you had with gradients. Normally torch.where should work fine with autograd as it computes more or less cond * left + (1 - cond) * right, except if left or right is NaN. So my guess is that evaluating the polynomial outside of [0,1] returns a NaN, which breaks autograd.

A solution could be to cast x to 0.5 when x∉[0,1] before evaluating the polynomial.

I had a similar problem in my TF implementation, and casting to 0.5 helped there, but I did not succeed with it here. I might give it another try.

Is it preferred to use autograd over analytical Jacobians?

Additionally, in the sigmoidal case, we discovered numerical issues, since the derivative of the sigmoid, sigma * (1 - sigma), converges to zero as x goes to +/- infinity:

[Figure: sigma]

Hence, we decided to add finfo(dtype).tiny to avoid infinities when taking the logarithm.
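To illustrate the underflow, here is a small float32 sketch comparing the naive log of sigma * (1 - sigma), the finfo(dtype).tiny workaround described above, and a log-space alternative based on the identity sigma'(x) = sigma(x) * sigma(-x), evaluated with torch.nn.functional.logsigmoid. The log-space variant is a swapped-in technique, not what this PR implemented.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0.0, 50.0, 200.0])  # float32
sig = torch.sigmoid(x)

# Naive: for large x, 1 - sigmoid(x) rounds to 0 in float32, so the log is -inf.
naive = torch.log(sig * (1 - sig))

# Workaround from the thread: add the smallest positive normal number before the log.
tiny = torch.finfo(x.dtype).tiny
clamped = torch.log(sig * (1 - sig) + tiny)  # finite, but saturates at log(tiny)

# Log-space alternative: log sigma'(x) = log sigma(x) + log sigma(-x).
stable = F.logsigmoid(x) + F.logsigmoid(-x)  # roughly -|x| for large |x|
```

The tiny offset avoids -inf but returns the same value (about -87.3) for x = 50 and x = 200, whereas the log-space version keeps tracking the true log-derivative.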

Also how did you handle the inverse transformation with the extrapolation? I guess you would also need to look whether y is larger/smaller than the values at the boundaries.

Good point! I had not thought of the inverse yet, but of course, the bisection is limited to the bounds. Thanks for pointing this out; I'll look into it.

francois-rozet commented 5 months ago

Is it preferred to use autograd over analytical Jacobians?

An analytical Jacobian is fine (if it is exact), but if the output $y$ of the transformation is not differentiable, the transformation itself cannot be trained through gradient descent.

Can you provide a small example demonstrating the incorrect behavior?
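One way to check an analytical log_abs_det_jacobian against autograd: for an element-wise transform, the Jacobian is diagonal, so a single backward pass of y.sum() yields that diagonal without materializing the full N x N matrix that torch.autograd.functional.jacobian builds in the script above. The helper name and the torch.sinh stand-in transform are assumptions for illustration.

```python
import torch


def elementwise_ladj_autograd(f, x):
    # Hypothetical helper: assumes f acts element-wise, i.e. dy_i/dx_j = 0 for i != j.
    # Then d(sum(y))/dx_i = dy_i/dx_i, which is the diagonal of the Jacobian.
    x = x.detach().requires_grad_(True)
    y = f(x)
    (dy_dx,) = torch.autograd.grad(y.sum(), x)
    return dy_dx.abs().log()


x = torch.linspace(-3.0, 3.0, 7)
autograd_ladj = elementwise_ladj_autograd(torch.sinh, x)
analytical_ladj = torch.log(torch.cosh(x))  # d/dx sinh(x) = cosh(x) > 0
```

If the analytical expression is exact, the two tensors agree up to floating-point error.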

Additionally, in the sigmoidal case, we discovered numerical issues, since the derivative of the sigmoid, sigma * (1 - sigma), converges to zero as x goes to +/- infinity

I see, but is it problematic in practice? Neural network layers expect standardized features (zero mean, unit variance), so the features are never too large.

Anyway if you have a working implementation with the extrapolation, we can drop the sigmoid mapping.

MArpogaus commented 5 months ago

A solution could be to cast x to 0.5 when x∉[0,1] before evaluating the polynomial.

This actually solved the problem; my binary mask was wrong in the first place. I updated the code accordingly and dropped the analytical Jacobian in a recent commit.

Anyway if you have a working implementation with the extrapolation, we can drop the sigmoid mapping.

Dropped it.

Also how did you handle the inverse transformation with the extrapolation? I guess you would also need to look whether y is larger/smaller than the values at the boundaries.

I have now added an _inverse function with inverse extrapolation.

[Figure: extrapolation with inverse]
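A minimal sketch of how an inverse can handle the extrapolated regions, assuming the polynomial lives on a hypothetical domain [-B, B] with hypothetical increasing coefficients theta, and that the forward pass continues it with the tangent lines at -B and +B. Outside [theta_0, theta_n] the inverse is the closed-form inverse of those lines; inside it falls back to bisection. This only illustrates the idea and is not the _inverse from the branch.

```python
import math
import torch

B = 5.0  # hypothetical bound: the polynomial is defined on [-B, B]
theta = torch.tensor([-4.0, -1.0, 0.5, 2.0, 6.0])  # hypothetical increasing coefficients
n = theta.numel() - 1
k = torch.arange(n + 1, dtype=theta.dtype)
binom = torch.tensor([math.comb(n, i) for i in range(n + 1)], dtype=theta.dtype)

# Boundary slopes dy/dx of the polynomial, used for the linear extrapolation.
slope_low = n * (theta[1] - theta[0]) / (2 * B)
slope_high = n * (theta[-1] - theta[-2]) / (2 * B)


def forward(x):
    z = (x + B) / (2 * B)  # map [-B, B] to [0, 1]
    inside = (z >= 0) & (z <= 1)
    z_safe = torch.where(inside, z, torch.full_like(z, 0.5)).unsqueeze(-1)
    y_poly = (theta * binom * z_safe**k * (1 - z_safe) ** (n - k)).sum(-1)
    y_low = theta[0] + slope_low * (x + B)     # tangent line at x = -B
    y_high = theta[-1] + slope_high * (x - B)  # tangent line at x = +B
    return torch.where(z < 0, y_low, torch.where(z > 1, y_high, y_poly))


def inverse(y, steps=60):
    # Outside [theta_0, theta_n], invert the tangent lines in closed form.
    x_low = -B + (y - theta[0]) / slope_low
    x_high = B + (y - theta[-1]) / slope_high
    # Inside, bisect on [-B, B]; forward is increasing because theta is increasing.
    lo = torch.full_like(y, -B)
    hi = torch.full_like(y, B)
    for _ in range(steps):
        mid = (lo + hi) / 2
        below = forward(mid) < y
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    x_mid = (lo + hi) / 2
    return torch.where(y < theta[0], x_low, torch.where(y > theta[-1], x_high, x_mid))


x = torch.linspace(-10.0, 10.0, 41)
y = forward(x)
x_rec = inverse(y)
```

The round trip inverse(forward(x)) recovers x both inside and outside the bounds, which is the behavior the bounded bisection alone could not provide.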

MArpogaus commented 5 months ago

I have added an additional option to get a smooth transition into the extrapolation by enforcing the second-order derivative to be zero at the bounds (b2869d4).

[Figure: without smooth bounds]

[Figure: with smooth bounds]
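For reference, the boundary conditions behind a smooth-bounds option follow from the standard derivative formulas of a Bernstein polynomial of order n on z in [0, 1]. This derivation is added for illustration only; it does not describe how b2869d4 implements the constraint.

```latex
\begin{align*}
f(z) &= \sum_{i=0}^{n} \theta_i \, b_{i,n}(z), \qquad b_{i,n}(z) = \binom{n}{i} z^i (1 - z)^{n-i}, \\
f'(z) &= n \sum_{i=0}^{n-1} (\theta_{i+1} - \theta_i) \, b_{i,n-1}(z), \\
f''(z) &= n(n-1) \sum_{i=0}^{n-2} (\theta_{i+2} - 2\theta_{i+1} + \theta_i) \, b_{i,n-2}(z).
\end{align*}
Since $b_{i,m}(0) = \delta_{i0}$ and $b_{i,m}(1) = \delta_{im}$:
\begin{align*}
f''(0) &= n(n-1)\,(\theta_2 - 2\theta_1 + \theta_0), \\
f''(1) &= n(n-1)\,(\theta_n - 2\theta_{n-1} + \theta_{n-2}).
\end{align*}
```

So the second derivative vanishes at a bound exactly when the two outermost coefficient increments on that side are equal, i.e. $\theta_2 - \theta_1 = \theta_1 - \theta_0$ at the lower end and $\theta_n - \theta_{n-1} = \theta_{n-1} - \theta_{n-2}$ at the upper end.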

@oduerr what do you think? Do we need such an option?

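I was thinking about adding another optional argument to specify the codomain of the polynomial.

A first minimal working example:

import torch

M = 10  # number of unconstrained parameters, as in the script above
eps = 1e-6  # hypothetical stand-in for bpoly.eps
low = -3.0
high = 33.0

positive_fn = torch.exp  # would map theta to positive offsets for flexible bounds
theta = torch.rand(size=(M,))  # creates a random parameter vector
theta_low = torch.tensor([low])  # low - theta[:1].exp() would optionally allow flexible bounds
theta_high = torch.tensor([high])  # high + theta[-1:].exp()
diff = torch.nn.functional.softmax(theta, dim=-1)  # use theta[1:-1] if flexible_bounds=True
diff = diff * ((theta_high - theta_low) - 2 * eps)

# Increasing coefficients that start at low and end at high
coefficients = torch.cat([torch.cumsum(torch.cat([theta_low, diff], dim=-1), dim=-1), theta_high], dim=-1)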

This is especially useful when chaining several transforms or using a base distribution with limited support.

francois-rozet commented 5 months ago

I have added an additional option to get a smooth transition into the extrapolation by enforcing the second-order derivative to be zero at the bounds.

Nice! I would actually enable this by default if it is not too expensive.

I was thinking about adding another optional argument to specify the codomain of the polynomial.

Why not! However, instead of specifying the co-domain, it is probably better to fix the co-domain to [-B, B] such that stacking several Bernstein transformations is safe and the co-domain covers the standard Gaussian distribution well. Note that this would not trivially enable base distributions with a limited support (e.g. uniform) as the extrapolation could lead to out-of-support values.

Note that implementing the [-B, B] co-domain would allow dropping the SoftclipTransform layers between transformations that are currently necessary in BPF.

MArpogaus commented 4 months ago

Nice! I would actually enable this by default if it is not too expensive.

OK, I enabled it by default in a recent commit.

Why not! However, instead of specifying the co-domain, it is probably better to fix the co-domain to [-B, B] such that stacking several Bernstein transformations is safe and the co-domain covers the standard Gaussian distribution well. Note that this would not trivially enable base distributions with a limited support (e.g. uniform) as the extrapolation could lead to out-of-support values.

I can add it like this for now, but I certainly need asymmetric bounds for my own models. Could you roughly sketch what has to be adjusted for Zuko to support this? Then I can work on it in a separate PR.

Note that implementing the [-B, B] co-domain would allow dropping the SoftclipTransform layers between transformations that are currently necessary in BPF.

Thanks for the hint. I'll make it conditional, depending on a new keep_in_bounds attribute inside the Transformation.

MArpogaus commented 4 months ago

Thanks for the hint. I'll make it conditional, depending on a new keep_in_bounds attribute inside the Transformation.

This is not possible, as the transform is not initialized when entering the BPF constructor. We can either add this as an argument to BPF or just keep it in the transform for now, for use in custom flows.

@francois-rozet What is your opinion?

MArpogaus commented 4 months ago

Just pushed a first draft: 13ead7d

francois-rozet commented 4 months ago

What I had in mind is to make the Bernstein transform such that it always (not as an option) maps $\pm B$ to $\pm B$. Like this you can stack many transformations without any risk of collapse. This is important because the co-domain should always cover the base distribution support.

This is what the rational quadratic spline transform does. It maps B to B and extrapolates linearly outside. In addition, the derivatives at the bounds are 1 so that $y = x$ when $x$ is outside of the bounds.
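To make the "derivative 1 at the bounds" idea concrete for a Bernstein polynomial, here is a sketch assuming the polynomial's domain and co-domain are both [-B, B] (with z = (x + B) / (2B)). With theta_0 = -B and theta_n = B, choosing the two outermost coefficient increments on each side equal to 2B/n gives slope 1 and zero curvature at the bounds, so the transform can be continued as y = x outside. B, the inner widths, and all names here are hypothetical.

```python
import math
import torch

B = 5.0  # hypothetical bound: domain and co-domain [-B, B]
inner = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)  # hypothetical positive widths
n = inner.numel() + 4  # polynomial order: 2 + len(inner) + 2 increments
w = 2 * B / n  # boundary increment giving dy/dx = n * w / (2B) = 1

# Increasing coefficients from -B to B: two equal increments w at each end
# (slope 1, zero curvature), rescaled inner increments in between.
inner = inner / inner.sum() * (2 * B - 4 * w)
edge = torch.full((2,), w, dtype=torch.float64)
increments = torch.cat([edge, inner, edge])
theta = torch.cat([torch.tensor([-B], dtype=torch.float64), -B + torch.cumsum(increments, 0)])

k = torch.arange(n + 1, dtype=torch.float64)
binom = torch.tensor([math.comb(n, i) for i in range(n + 1)], dtype=torch.float64)


def poly(x):
    z = ((x + B) / (2 * B)).unsqueeze(-1)  # map [-B, B] to [0, 1]
    return (theta * binom * z**k * (1 - z) ** (n - k)).sum(-1)


def transform(x):
    # Polynomial inside [-B, B], identity outside; continuous since theta_0 = -B, theta_n = B.
    return torch.where(x.abs() <= B, poly(x.clamp(-B, B)), x)


# Finite-difference check of slope and curvature at the lower bound.
h = 1e-4
y = poly(torch.tensor([-B, -B + h, -B + 2 * h], dtype=torch.float64))
slope_low = (y[1] - y[0]) / h
curvature_low = (y[2] - 2 * y[1] + y[0]) / h**2
```

Because the slope is 1 and the curvature vanishes at the bounds, the identity continuation joins the polynomial smoothly, and stacking such transforms keeps [-B, B] mapped onto itself.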

Note that if you need another (asymmetric) domain or co-domain you can always combine a transformation with a (monotonic) affine transformation.
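As a sketch of that suggestion: torch.distributions.transforms.AffineTransform (a standard PyTorch transform) can map the symmetric range [-B, B] onto an asymmetric one such as the [low, high] = [-3, 33] from the earlier example. B = 10 is a hypothetical value, and how this would be composed with a zuko transform is not shown here.

```python
import torch
from torch.distributions.transforms import AffineTransform

B = 10.0  # hypothetical symmetric bound of the Bernstein co-domain
low, high = -3.0, 33.0  # asymmetric target range from the example above

# Monotonic affine map sending [-B, B] onto [low, high]: y -> loc + scale * y.
affine = AffineTransform(loc=(high + low) / 2, scale=(high - low) / (2 * B))

y = torch.tensor([-B, 0.0, B])
z = affine(y)  # approximately [-3, 15, 33]
ladj = affine.log_abs_det_jacobian(y, z)  # constant log(1.8) per element
y_back = affine.inv(z)
```

Since the scale is positive, the map is monotonic, and its log-Jacobian is a constant that simply adds to the log-density.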

What do you think?

MArpogaus commented 4 months ago

OK, that is basically what I implemented, but I have not enabled it by default yet.

What is your opinion on chaining flexible transformations? I normally consider this an unnecessary increase in complexity and prefer single-layer models with just one very flexible polynomial. But it's certainly nice to have the ability to chain them.

Currently, _constrain_theta takes the argument bound and transforms the coefficients accordingly.

If we want, we could also force the derivative to 1 at the bounds to ensure we get the identity function when extrapolating.

It could look something like this:

import math

import torch
from torch import Tensor


def _constrain_theta(unconstrained_theta: Tensor, bound: float) -> Tensor:
    r"""Processes the unconstrained output of the hyper-network to be increasing."""

    if bound:
        theta_min = -bound * torch.ones_like(unconstrained_theta[..., :1])
        # Order of the resulting polynomial: theta_min + 2 + D + 2 coefficients.
        order = unconstrained_theta.shape[-1] + 4
        # Assuming the polynomial domain is [-bound, bound], an increment of
        # 2 * bound / order gives slope 1; two equal increments give zero curvature.
        edge_width = 2 * bound / order

        def fn(x):
            edge = x.new_full((*x.shape[:-1], 2), edge_width)
            return torch.cat(
                (
                    edge,  # f'(-bound) = 1, f''(-bound) = 0
                    torch.nn.functional.softmax(x, dim=-1) * (2 * bound - 4 * edge_width),
                    edge,  # f'(bound) = 1, f''(bound) = 0
                ),
                dim=-1,
            )

    else:
        shift = math.log(2.0) * unconstrained_theta.shape[-1] / 2
        theta_min = unconstrained_theta[..., :1] - shift
        unconstrained_theta = unconstrained_theta[..., 1:]
        fn = torch.nn.functional.softplus

    widths = fn(unconstrained_theta)

    widths = torch.cat((theta_min, widths), dim=-1)
    theta = torch.cumsum(widths, dim=-1)

    return theta

However, I would personally prefer to keep an option for a simple "ordered" constraint function in the library, as we had in the beginning, for less restrictive setups. @oduerr What is your opinion?

francois-rozet commented 4 months ago

What is your opinion on chaining flexible transformations? I normally consider this an unnecessary increase in complexity and prefer single-layer models with just one very flexible polynomial.

In multivariate flows, it is generally necessary to chain several multi-variate (think autoregressive or coupling) transformations, even if the univariate transformations are very expressive. And even if a single multi-variate transformation is used, you should ensure that its co-domain covers the support of the base distribution.

However, I would personally prefer to keep an option for a simple "ordered" constraint function in the library, as we had in the beginning, for less restrictive setups.

You mean the one currently in the lib or one with extrapolation + smooth bounds?

Anyway, instead of adding more and more options to the class you can also sub-class it (e.g. BoundedBernsteinTransform).

oduerr commented 3 months ago

Sorry for the late reply.

We (@MArpogaus and I) just had a discussion and are a bit puzzled by:

In multivariate flows, it is generally necessary to chain several multi-variate (think autoregressive or coupling) transformations, even if the univariate transformations are very expressive.

We thought that if the 1-D transformation function is flexible enough, using AR flows allows us to express even complex distributions due to the chain rule of probability. See also [1]. While we still believe this is true in theory, we observed much better performance on the UCI benchmarks when we chained transformations. Do you have an intuition for why chaining gives better performance, or do you know of a paper?

@MArpogaus will implement the two classes you suggested in the coming week: one unbounded and one bounded, in which the coefficients are scaled. Sorry for the delay.


[1] G. Papamakarios, E. Nalisnick, D. J. Rezende, S. Mohamed, and B. Lakshminarayanan, “Normalizing Flows for Probabilistic Modeling and Inference,” Journal of Machine Learning Research, vol. 22, no. 57, pp. 1–64, 2021.

francois-rozet commented 3 months ago

Hello @oduerr, no problem, everyone is busy :smile:

We thought that if the 1-D transformation function is flexible enough, using AR flows allows us to express even complex distributions due to the chain rule of probability. See also [1]. While we still believe this is true in theory, we observed much better performance on the UCI benchmarks when we chained transformations. Do you have an intuition for why chaining gives better performance, or do you know of a paper?

You are right, a single auto-regressive transformation should be enough if the uni-variate transformation is a universal (monotonic) function approximator. However, the hyper network (MADE/MaskedMLP) conditioning the transformation must also be a universal function approximator. In practice, the capacity of the hyper network is finite, and the function it should approximate might be complex depending on the data, the uni-variate function (and its parametrization), and the auto-regressive order. The latter, in particular, can have a huge impact: some orders can lead to very simple functions while others can lead to almost unlearnable functions. Stacking several multi-variate transformations with different (sometimes randomized) orders helps to alleviate this issue.

For example, if $p(x_1, x_2, \dots, x_n) = p(x_1) \prod_{i=2}^{n} p(x_i \mid x_1)$ (star-shaped dependency graph), the order $1 \to n$ leads to a simple function, while the order $n \to 1$ does not. This is somewhat related to finding the Bayesian network (a directed acyclic graph) with the fewest edges that explains the data. In fact, the "simplest" Bayesian network is sometimes confused with the causal graph of the data.
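A small numerical illustration of this example, under a swapped-in assumption: the star-shaped distribution is taken to be a hypothetical linear-Gaussian model (x_1 ~ N(0, 1), x_i = a * x_1 + noise), for which each autoregressive conditional mean is linear in the preceding variables. Counting the non-zero regression coefficients shows how many inputs each conditioner needs in the order 1 -> n versus n -> 1 (index 0 in the code is x_1).

```python
import torch


def conditioner_inputs(cov, order, tol=1e-8):
    # For a Gaussian, E[x_i | x_S] = beta^T x_S with beta = cov[S, S]^{-1} cov[S, i].
    # Return, for each step of the autoregressive order, how many predecessors
    # actually enter the conditional mean (non-zero entries of beta).
    counts = []
    for pos in range(1, len(order)):
        i, S = order[pos], order[:pos]
        beta = torch.linalg.solve(cov[S][:, S], cov[S, i])
        counts.append(int((beta.abs() > tol).sum()))
    return counts


# Hypothetical star-shaped model: x_1 ~ N(0, 1), x_i = a * x_1 + N(0, s^2) for i > 1.
n, a, s = 5, 0.8, 0.5
cov = torch.full((n, n), a * a, dtype=torch.float64)
cov[0, :] = a
cov[:, 0] = a
cov[0, 0] = 1.0
for i in range(1, n):
    cov[i, i] = a * a + s * s

forward_counts = conditioner_inputs(cov, list(range(n)))            # order 1 -> n
reverse_counts = conditioner_inputs(cov, list(reversed(range(n))))  # order n -> 1
```

Here forward_counts is [1, 1, 1, 1] (every conditioner only needs x_1), while reverse_counts is [1, 2, 3, 4] (each conditioner needs all of its predecessors), which mirrors the simple versus hard orders described above.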

oduerr commented 3 months ago

Hello @francois-rozet,

thanks for your patience. And many thanks for the intuitive example that makes it very clear!

MArpogaus commented 3 months ago

Hello everybody,

I have finally found the time to continue working on this. Here is what I have so far:

There are now two alternative versions of the Bernstein Transform:

The class BernsteinTransform is similar to the previous implementation, but it extrapolates linearly outside the bounds and applies the "smooth bound constraint" discussed above to the coefficients. This is no longer configurable, as we agreed to implement different behavior via subclassing. However, I don't have a strong opinion here; if you want, we could easily add a conditional for that.

[Figure: BernsteinTransform with extrapolation and inverse]

Additionally, BoundedBernsteinTransform is an alternative implementation optimized for chaining. There are two major differences:

[Figure: BoundedBernsteinTransform with extrapolation and inverse]

For now, BPF uses the bounded version.

What are your opinions?

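I would like to discuss two remaining questions:

  1. Should we provide an alternative BPF implementation based on the previous attempt with soft clip and fewer constraints on the parameters?
  2. Should we provide a univariate version using the unbounded polynomial?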

francois-rozet commented 3 months ago

This looks really nice! I am curious to see how it compares to the other flows, especially NSF.

Should we provide an alternative BPF implementation based on the previous attempt with soft clip and fewer constraints on the parameters?

Unless you think it is necessary, I am fine with (and actually prefer) only proposing BPF with BoundedBernsteinTransform. Softclip usually degrades performance quite a lot.

Should we provide a univariate version using the unbounded polynomial?

I don't think so. Users can construct their own flow with BernsteinTransform if necessary.

Flow(
    transform=ElementWiseTransform(
        features=1,
        ...
        univariate=BernsteinTransform,
        shape=[(16,)],
    ),
    base=...,
)
MArpogaus commented 3 months ago

Hey @francois-rozet, thanks for the fast response!

I am fine with it too. How about you, @oduerr? Does this also meet your requirements?

oduerr commented 3 months ago

Hello @Marcel,

Thanks a lot for the changes; they look fine. The idea of fixing the coefficients for a bounded transformation is quite clever! There is no need from my side for an unbounded flow as I can roll my own easily, as described in https://github.com/probabilists/zuko/issues/36#issuecomment-2022804132

MArpogaus commented 3 months ago

This looks really nice! I am curious to see how it compares to the other flows, especially NSF.

Some early results on "checkerboard" data, with default BPF parameters:

[Figure: Dataset]

[Figure: Samples]

[Figure: Evolution through training (animation)]

[Figure: Density (2D and 3D views)]

@francois-rozet do you have some results from other flows that you can share?

francois-rozet commented 3 months ago

Hello @MArpogaus, here are the results for NSF with basic settings.

You can check the run at https://wandb.ai/francois-rozet/zuko-benchmark-2d/runs/5xllo4bz .

francois-rozet commented 2 months ago

Closing as #37 was merged :partying_face: