Closed by mdmould 1 year ago
Thanks for the report. I've had a quick look and I think I am starting to piece together what the issue is. The codomain of the BlockAutoregressiveNetwork bijection is not unbounded: the output is a linear transformation of a tanh-transformed variable, so it is confined to a bounded interval. This leads to an issue equivalent to the example below:
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, Tanh
dist = Transformed(Normal(), Invert(Tanh()))  # Invert(Tanh()) is arctanh, with domain (-1, 1)
x = jnp.linspace(-5, 5, 1000)
y = jnp.exp(dist.log_prob(x))
print(jnp.trapz(y, x))  # ~0.68, i.e. the mass of a standard normal on (-1, 1)
Generally, because we don't have explicit constraints, it is up to users to ensure that the bijection is valid across the support of the distribution. However, obviously with the flow/BlockAutoregressiveNetwork this is a bug. In practice, with training, the flow learns to utilise all (or nearly all) the mass from the base distribution, but upon initialisation that is not the case, leading to the issue you found. I'll have a look into it and see if other implementations do things differently.
I see, thanks for looking into this. Are there other monotonic bijections that could be used in place of tanh, so that the transformed distribution has support over the whole real line prior to training?
Sure, that was my first thought for the easiest way to solve it. They mention in the paper that leaky ReLU could be used (though they report no results for it). I hacked in a leaky ReLU, and it does result in a normalised distribution, but the performance is much worse than with tanh. I wonder if this issue exists in the code associated with the paper (here)? I don't see any obvious way they handle it.
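For reference, a leaky-ReLU-style activation is strictly increasing with the whole real line as its image, and its log gradient has a simple closed form. This is just a rough sketch of the idea (the helper names are hypothetical, not necessarily the version that was hacked in):
import jax.numpy as jnp

def leaky_relu(x, negative_slope=0.01):
    # Strictly increasing, with the whole real line as its image.
    return jnp.where(x >= 0, x, negative_slope * x)

def log_abs_grad_leaky_relu(x, negative_slope=0.01):
    # The gradient is 1 for x >= 0 and negative_slope otherwise, so the
    # log-abs-gradient has a simple closed form (no autodiff needed).
    return jnp.where(x >= 0, 0.0, jnp.log(negative_slope))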
One option then is to scale the tanh to increase the support of the inverse. Perhaps adding the choice of activation function to BlockNeuralAutoregressiveFlow, along with helper functions to construct the required form of the activation for BlockAutoregressiveNetwork, would be useful.
Very quick example which seems to work:
import jax
import jax.numpy as jnp
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import Invert, BlockAutoregressiveNetwork
from flowjax.nn.block_autoregressive import _reshape_jacobian_to_3d
import matplotlib.pyplot as plt
def wrap_activation(activation, n_blocks):
    # Wrap an elementwise activation so it also returns the log-abs-det-Jacobian
    # in the 3D block form expected by BlockAutoregressiveNetwork.
    grad = jax.vmap(jax.grad(activation))
    ladj = lambda x: jnp.sum(jnp.log(jnp.abs(grad(x))), axis=-1, keepdims=True)
    return lambda x: (activation(x), _reshape_jacobian_to_3d(ladj(x), n_blocks))
key = jax.random.PRNGKey(0)
dim = 1
cond_dim = None
depth = 1
block_dim = 1
base_dist = Normal(jnp.zeros(1), jnp.ones(1))
x = jnp.linspace(-10, 10, 1_000)
for scale in jnp.arange(1, 11):
    # Scale tanh so its image is (-scale, scale), widening the support of the inverse.
    activation = lambda x, scale=scale: jax.nn.tanh(x) * scale  # bind scale to avoid late binding
    new_activation = wrap_activation(activation, dim)
    net = BlockAutoregressiveNetwork(key, dim, cond_dim, depth, block_dim, new_activation)
    bijection = Invert(net)
    flow = Transformed(base_dist, bijection)
    p = jnp.exp(flow.log_prob(x[:, None]))
    norm = jnp.trapz(p, x)
    plt.plot(x, p, label=f'{scale}: {norm}')
plt.legend()
plt.show()
However, this does not guarantee correct initial normalization for any given seed.
Thanks! Just some thoughts.
The way custom activation functions currently have to be defined for BlockAutoregressiveNetwork (returning the reshaped log-det-Jacobian alongside the output, as in your wrap_activation) is awkward (a mistake on my part). Given I doubt many people have constructed custom activation functions for the flow, I'm inclined to change this on main.
Computing the log gradient with autodiff, i.e. log(abs(grad(tanh)(x))), can also be numerically unstable, as grad(tanh) may be very close to zero. So some way to pass in custom log grad functions (e.g. by passing in a bijection object) is nice when we have more stable forms of the log gradient.
As for the activation itself, another option is something like
def activation(x):
    y = jnp.tanh(x)
    return y + 0.01 * jnp.abs(y) * x
Notice that for very large or small values of x, jnp.abs(y) becomes approximately 1, so the gradient will be at least ~0.01, and the codomain of the activation is the whole real line (rather than the interval [-1, 1]). That should ensure the distribution is normalised (note it may have heavy tails).
There is a branch here if you want to have a play around: https://github.com/danielward27/flowjax/tree/bnaf_bug. You can pass callables (which use autodiff, like in your example) or bijections with shape () as activation functions to the flow.
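For example, a bijection passed this way could use a closed-form log gradient for tanh rather than differentiating through it. A minimal sketch of such a stable log-gradient helper (the function name is hypothetical, not part of the branch's API):
import jax.numpy as jnp
from jax.nn import softplus

def log_abs_grad_tanh(x):
    # log|d/dx tanh(x)| = log(1 - tanh(x)**2) = 2 * (log(2) - x - softplus(-2 * x)),
    # which stays accurate even where 1 - tanh(x)**2 underflows to zero.
    return 2 * (jnp.log(2.0) - x - softplus(-2.0 * x))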
I'm hesitant to add this into main at the moment, as there are some relatively arbitrary decisions to make: for example, which activation function to use by default, and whether we should train the activation function's parameters if it has any (e.g. if the activation is defined as a Bijection or a callable module).
Numpyro seems to have the same issue, so hopefully they will have some ideas too.
Thanks for investigating!
Yep, this issue is fundamental to the original implementation of BNAF, due to the image of tanh being only a bounded subset of the real line, so other packages that implement it that way will have the same issue. It's more of a design issue than a bug.
In terms of activation functions, they still must be monotonic and invertible to ensure the map is bijective and thus has a unique inverse, which limits the choice somewhat. I think having a sensible default but letting the user customize, as in the new branch (with a warning in the docstring about the requirements on the activation function), is a good idea. I don't see any reason why activations with trainable parameters would be an issue, as long as they always respect the required bijectivity, as their parameters should be automatically picked up in the pytree.
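For instance, a trainable scaled-tanh activation could be written as a small equinox module whose scale is just another leaf of the pytree. This is a hypothetical sketch of the idea, not flowjax's actual Bijection interface:
import equinox as eqx
import jax.numpy as jnp
from jax import Array

class ScaledTanh(eqx.Module):
    # The scale is a leaf of the module pytree, so it is trained along with
    # the rest of the flow's parameters.
    scale: Array

    def __init__(self, scale: float = 5.0):
        self.scale = jnp.asarray(scale)

    def __call__(self, x):
        # Strictly increasing, with image (-scale, scale).
        return jnp.tanh(x) * self.scale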
The log_prob evaluations of an untrained instance of BlockNeuralAutoregressiveFlow are not correctly normalized. E.g.:
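A minimal sketch of the kind of check that shows this (the constructor call is an assumption and may differ from the exact API used):
import jax
import jax.numpy as jnp
from flowjax.distributions import Normal
from flowjax.flows import BlockNeuralAutoregressiveFlow

key = jax.random.PRNGKey(0)
# Assumed constructor: a PRNG key and a 1D standard normal base distribution.
flow = BlockNeuralAutoregressiveFlow(key, Normal(jnp.zeros(1), jnp.ones(1)))
x = jnp.linspace(-10, 10, 1_000)
p = jnp.exp(flow.log_prob(x[:, None]))
print(jnp.trapz(p, x))  # noticeably less than 1 before any training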
The same is true if the support of the transformed distribution is explicitly bounded (this is unsurprising in hindsight, as the additional log-abs-det-Jacobian from the bounding bijection of course cannot compensate for the bijections that precede it), e.g.:
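Continuing the sketch above, bounding the support with a Tanh bijection shows the same shortfall (the shape argument to Tanh is assumed here and may differ from the actual API):
from flowjax.bijections import Tanh
from flowjax.distributions import Transformed

bounded_flow = Transformed(flow, Tanh((1,)))  # support bounded to (-1, 1)
u = jnp.linspace(-0.999, 0.999, 1_000)
q = jnp.exp(bounded_flow.log_prob(u[:, None]))
print(jnp.trapz(q, u))  # still integrates to noticeably less than 1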
Is this expected behaviour (perhaps related to initialization; Appendix C of https://arxiv.org/abs/1904.04676)? This does not occur for other flows, e.g. MaskedAutoregressiveFlow. In practice it's not really an issue, because correct normalization is recovered once the flow is trained on some samples.