danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

BlockNeuralAutoregressiveFlow not initially normalized correctly #102

Closed: mdmould closed this issue 1 year ago

mdmould commented 1 year ago

The log_prob evaluations of an untrained instance of BlockNeuralAutoregressiveFlow are not correctly normalized.

E.g.:

import jax
import jax.numpy as jnp
from flowjax.distributions import Normal
from flowjax.flows import BlockNeuralAutoregressiveFlow
import matplotlib.pyplot as plt

flow = BlockNeuralAutoregressiveFlow(
    key=jax.random.PRNGKey(0),
    base_dist=Normal(jnp.zeros(1)),
    cond_dim=None,
    nn_depth=1,
    nn_block_dim=1,
    flow_layers=1,
    invert=True,
    )

x = jnp.linspace(-100, 100, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))  # should be ~1 for a normalized density, but falls well short

The same is true if the support of the transformed distribution is explicitly bounded (this is unsurprising in hindsight, as the additional log-abs-det-Jacobian from the bounding bijection of course cannot account for any previous bijections), e.g.:

import jax
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, Chain, Tanh, BlockAutoregressiveNetwork
import matplotlib.pyplot as plt

bijections = [
    Invert(
        BlockAutoregressiveNetwork(
            key=jax.random.PRNGKey(0),
            dim=1,
            cond_dim=None,
            depth=1,
            block_dim=1,
            ),
        ),
    Tanh(shape=(1,)),
    ]

flow = Transformed(Normal(jnp.zeros(1)), Chain(bijections))

x = jnp.linspace(-2, 2, 100_000)
y = jnp.exp(flow.log_prob(x[:, None]))
# plt.plot(x, y); plt.show()
print(jnp.trapz(y, x))  # still not ~1

Is this expected behaviour (perhaps related to initialization; Appendix C of https://arxiv.org/abs/1904.04676)? This does not occur for other flows, e.g., MaskedAutoregressiveFlow. In practice it's not really an issue, because the correct normalization is preserved once the flow is trained to some samples.

danielward27 commented 1 year ago

Thanks for the report. I've had a quick look and I think I'm starting to piece together what the issue is. The codomain of the BlockAutoregressiveNetwork bijection is not unbounded, as the output is a linear transformation of a tanh-transformed variable. This leads to an issue equivalent to the one below:

import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, Tanh

dist = Transformed(Normal(), Invert(Tanh()))  # arctanh has domain (-1, 1)
x = jnp.linspace(-5, 5, 1000)
y = jnp.exp(dist.log_prob(x))
print(jnp.trapz(y, x)) # 0.68 i.e. integrate normal from -1 to 1

Generally, because we don't have explicit constraints, it is up to users to ensure that the bijection is valid across the support of the distribution. However, with the flow / BlockAutoregressiveNetwork this is obviously a bug. In practice, with training, the flow learns to utilise all (or nearly all) of the mass from the base distribution, but upon initialisation that is not the case, leading to the issue you found. I'll have a look into it and see if other implementations do things differently.
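As a quick sanity check, the 0.68 above is exactly the standard normal mass on [-1, 1], i.e. the only part of the base distribution that arctanh can reach:

from jax.scipy.stats import norm

# Probability that a standard normal sample lands inside arctanh's domain (-1, 1)
print(norm.cdf(1.0) - norm.cdf(-1.0))  # ~0.683, matching the integral above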

mdmould commented 1 year ago

I see, thanks for looking into this. Are there other monotonic bijections that could be used in place of tanh, so that the flow has full support over the base distribution prior to training?

danielward27 commented 1 year ago

Sure, that was my first thought for the easiest way to solve it. The paper mentions that leaky ReLU could be used (though they report no results for it). I hacked in a leaky ReLU and it does result in a normalised distribution, but the performance is much worse than with tanh. I wonder if this issue exists in the code associated with the paper (here)? I don't see any obvious way they handle it.
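For reference, a minimal sketch of what such an activation and its exact log-gradient look like (the 0.01 slope just matches jax.nn.leaky_relu's default; the function names are illustrative, and wiring this into BlockAutoregressiveNetwork is the hacked-in part):

import jax.numpy as jnp

# Leaky ReLU is monotonic with gradient bounded below by the slope, so its
# codomain is the whole real line and no base mass becomes unreachable.
def leaky_relu(x, negative_slope=0.01):
    return jnp.where(x >= 0, x, negative_slope * x)

def leaky_relu_log_grad(x, negative_slope=0.01):
    # Closed-form log|d/dx leaky_relu(x)|: 0 for x >= 0, log(slope) otherwise
    return jnp.where(x >= 0, 0.0, jnp.log(negative_slope))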

mdmould commented 1 year ago

One option then is to scale the tanh to increase the support of the inverse.

Perhaps adding the choice of activation function to BlockNeuralAutoregressiveFlow and helper functions to construct the required form of the activation for BlockAutoregressiveNetwork would be useful.

Very quick example which seems to work:

import jax
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, BlockAutoregressiveNetwork
from flowjax.nn.block_autoregressive import _reshape_jacobian_to_3d
import matplotlib.pyplot as plt

def wrap_activation(activation, n_blocks):
    # Returns a callable giving (activation(x), log-abs-det-Jacobian), with the
    # Jacobian term reshaped to the 3D block format BlockAutoregressiveNetwork expects
    grad = jax.vmap(jax.grad(activation))
    ladj = lambda x: jnp.sum(jnp.log(jnp.abs(grad(x))), axis=-1, keepdims=True)
    return lambda x: (activation(x), _reshape_jacobian_to_3d(ladj(x), n_blocks))

key = jax.random.PRNGKey(0)
dim = 1
cond_dim = None
depth = 1
block_dim = 1

base_dist = Normal(jnp.zeros(1), jnp.ones(1))
x = jnp.linspace(-10, 10, 1_000)

for scale in jnp.arange(1, 11):
    # Scaled tanh: widens the codomain of the forward pass and hence the
    # support of the inverted flow
    activation = lambda x: jax.nn.tanh(x) * scale
    new_activation = wrap_activation(activation, dim)
    net = BlockAutoregressiveNetwork(key, dim, cond_dim, depth, block_dim, new_activation)
    bijection = Invert(net)
    flow = Transformed(base_dist, bijection)

    p = jnp.exp(flow.log_prob(x[:, None]))
    norm = jnp.trapz(p, x)
    plt.plot(x, p, label=f'{scale}: {norm}')

plt.legend()
plt.show()

However, this does not guarantee correct initial normalization for any given seed.
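For example, a quick check over a few seeds at a fixed scale (a sketch reusing wrap_activation and the other names from the snippet above; the particular seeds and scale are arbitrary):

scale = 5.0
activation = lambda x: jax.nn.tanh(x) * scale
for seed in range(5):
    net = BlockAutoregressiveNetwork(
        jax.random.PRNGKey(seed), dim, cond_dim, depth, block_dim,
        wrap_activation(activation, dim),
        )
    flow = Transformed(base_dist, Invert(net))
    p = jnp.exp(flow.log_prob(x[:, None]))
    # The integral can still fall short of 1 for some seeds
    print(seed, jnp.trapz(p, x))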

danielward27 commented 1 year ago

Thanks! Just some thoughts.

  1. There is no real reason why the user should have to wrap the activation; that should just be handled within BlockAutoregressiveNetwork (a mistake on my part). Given that I doubt many people have constructed custom activation functions for the flow, I'm inclined to change this on main.
  2. Being able to pass a callable activation is nice, but relying on automatic differentiation can be a bit problematic. For example, with tanh we will run into numerical precision issues, since we are computing log(abs(grad(tanh)(x))) and grad(tanh)(x) may be very close to zero. So some way to pass in custom log-gradient functions (e.g. by passing in a bijection object) is nice when a more stable form of the log gradient is available.
  3. Yes, you can probably improve things by scaling, but I think a more robust solution would be to use something like the following activation:
def activation(x):
    # tanh plus a small leak, so the gradient never vanishes for large |x|
    y = jnp.tanh(x)
    return y + 0.01 * jnp.abs(y) * x

Notice that for very large or small values of x, jnp.abs(y) becomes approximately 1, so the gradient stays at roughly 0.01 or above, and the codomain of the activation is the whole real line (rather than the interval (-1, 1)). That should ensure the distribution is normalised (note it may have heavy tails); a quick check is sketched below.
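For instance, plugging this into the wrap_activation helper from your example above (a sketch reusing the names defined there, with the autodiff precision caveat from point 2 still applying):

net = BlockAutoregressiveNetwork(
    jax.random.PRNGKey(0), dim, cond_dim, depth, block_dim,
    wrap_activation(activation, dim),
    )
flow = Transformed(base_dist, Invert(net))
p = jnp.exp(flow.log_prob(x[:, None]))
print(jnp.trapz(p, x))  # should now be close to 1 at initialisation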

danielward27 commented 1 year ago

There is a branch here if you want to have a play around: https://github.com/danielward27/flowjax/tree/bnaf_bug. You can pass callables (which use autodiff, like in your example) or bijections with shape () as activation functions to the flow.

I'm hesitant to add this into main at the moment, as there are some relatively arbitrary decisions to make. For example, what activation function to use, and whether we should train the activation function's parameters if it has any (e.g. if the activation is defined as a Bijection or a callable module).

NumPyro seems to have the same issue, so hopefully they will have some ideas too.

mdmould commented 1 year ago

Thanks for investigating!

Yep, this issue is fundamental to the original implementation of BNAF, since the image of tanh is a bounded subset of the real line, so other packages that implement it the same way will have the same issue. It's more of a design issue than a bug.

In terms of activation functions, they must still be monotonic (and hence invertible) to ensure the overall map is bijective and has a unique inverse, which limits the choice somewhat. I think having a sensible default but letting the user customize, as in the new branch (with a warning in the docstring about the requirements on the activation function), is a good idea. I don't see any reason why activations with trainable parameters would be an issue, as long as they always respect the required bijectivity, since their parameters should be picked up automatically in the pytree.
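To make that last point concrete, a trainable activation that stays monotonic could be a small callable module along these lines (a sketch assuming equinox, which flowjax is built on; the class name and parameterization are hypothetical, and how it plugs into the bnaf_bug branch's activation argument is up to that branch):

import jax
import jax.numpy as jnp
import equinox as eqx

class TrainableLeakyTanh(eqx.Module):
    # Hypothetical activation: tanh(x) + softplus(raw_leak) * x. The softplus
    # keeps the leak positive, so the activation is strictly increasing (hence
    # bijective) for any parameter value, and raw_leak sits in the pytree so it
    # would be trained along with the rest of the flow.
    raw_leak: jax.Array

    def __call__(self, x):
        return jnp.tanh(x) + jax.nn.softplus(self.raw_leak) * x

activation = TrainableLeakyTanh(raw_leak=jnp.array(-4.0))  # softplus(-4) gives a small initial leak of ~0.018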