danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Infinite Loss When Using Uniform Base Distribution to Model Gaussian Data with MAF #147

Closed llaurabatt closed 5 months ago

llaurabatt commented 5 months ago

I am new to your package and am playing with your Bounded Flow example. I am using MAF with a rational quadratic spline transformer and want to use a standard uniform as the base distribution. If I simulate $x$ from a standard Gaussian and train the flow with MLE, I immediately get an infinite loss, and samples from the flow do not recover the target.

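import jax.numpy as jnp
import jax.random as jr
import flowjax.distributions
from flowjax.bijections import RationalQuadraticSpline
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

nvars = 2
key, x_key = jr.split(jr.PRNGKey(0))
x = jr.normal(x_key, shape=(5000, nvars))  # target: standard Gaussian samples

key, subkey = jr.split(key)  # split the running key instead of re-splitting PRNGKey(0)
base_distr = flowjax.distributions._StandardUniform((nvars,))  # support [0, 1]^nvars

# Create the flow
untrained_flow = masked_autoregressive_flow(
    key=subkey,
    base_dist=base_distr,
    transformer=RationalQuadraticSpline(knots=8, interval=4),
)

key, subkey = jr.split(key)
# Train 
flow, losses = fit_to_data(
    key=subkey,
    dist=untrained_flow,
    x=x,
    learning_rate=5e-4,
    max_patience=10,
    max_epochs=70,
)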

The same happens if I use a Uniform with larger support:

base_distr = flowjax.distributions.Uniform(minval=jnp.ones(nvars)*-3, maxval=jnp.ones(nvars)*3)

I have also tried building an "unbounded uniform" class for the base distribution, where I pass uniform samples through an inverse tanh, the same way you do in the example:

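from flowjax.bijections import Affine, Chain, Invert, Tanh
from flowjax.distributions import AbstractTransformed

class UnboundedUniform(AbstractTransformed):
    """Standard uniform mapped to (-1, 1) by an affine, then to the real line by arctanh."""
    base_dist: flowjax.distributions._StandardUniform
    bijection: Chain

    def __init__(self, shape):
        eps = 1e-7  # stay away from +-1, where arctanh diverges
        self.base_dist = flowjax.distributions._StandardUniform(shape)
        # Maps [0, 1] onto [-1 + eps, 1 - eps]
        affine_transformation = Affine(loc=-jnp.ones(shape) + eps, scale=(1 - eps) * jnp.ones(shape)*2)
        inverse_tanh_transformation = Invert(Tanh(shape=shape))
        self.bijection = Chain([affine_transformation, inverse_tanh_transformation])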

base_distr = UnboundedUniform((nvars,))

but I get the same behavior. Is this expected, am I doing something wrong or silly, or is there a workaround? Thank you for the help and all the good work!

danielward27 commented 5 months ago

I think you are close with the second approach; however, if your base distribution is defined on an interval, it seems a shame not to use the spline on that same interval. Here's the approach that came to mind for me:
1) Define the base distribution and the transformer on [-1, 1]. This gives a flow with support on [-1, 1].
2) After the flow layers, transform to unbounded support with an inverse tanh.
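To see why step 2 removes the infinite loss, here is a minimal pure-JAX sketch of the change-of-variables calculation (not flowjax code). The cubic `g_inv` is a hypothetical stand-in for the inverse of a trained spline on [-1, 1]. Without the final arctanh the support is only [-1, 1], so any Gaussian sample outside it gets log-probability -inf, which is likely the same bounded-support mechanism behind the first two attempts in the report. With the arctanh, tanh maps every real y back into [-1, 1], so the log-probability is finite everywhere.

```python
import jax.numpy as jnp

LOG_HALF = -jnp.log(2.0)  # log density of Uniform(-1, 1) on its support


def g_inv(x):
    # Hypothetical stand-in for the inverse of a trained spline on [-1, 1]:
    # strictly increasing and maps [-1, 1] onto [-1, 1]
    return (x**3 + x) / 2


def g_inv_deriv(x):
    # Derivative of g_inv, always positive
    return (3 * x**2 + 1) / 2


def bounded_log_prob(x):
    # Uniform(-1, 1) base pushed through the spline stand-in: support is [-1, 1]
    u = g_inv(x)
    in_support = jnp.abs(u) <= 1
    return jnp.where(in_support, LOG_HALF + jnp.log(g_inv_deriv(x)), -jnp.inf)


def unbounded_log_prob(y):
    # Same flow followed by arctanh: invert with tanh and add the tanh log-Jacobian
    x = jnp.tanh(y)
    # Numerically stable log(1 - tanh(y)**2)
    abs_y = jnp.abs(y)
    log_det_tanh = 2.0 * (jnp.log(2.0) - abs_y - jnp.log1p(jnp.exp(-2.0 * abs_y)))
    return bounded_log_prob(x) + log_det_tanh


ys = jnp.array([-8.0, 0.0, 2.0, 8.0])
print(bounded_log_prob(ys))    # -inf wherever |y| > 1
print(unbounded_log_prob(ys))  # finite everywhere
```

The spline only reshapes mass inside [-1, 1]; the arctanh is what stretches that interval over the whole real line.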

As Uniform is itself a transformed distribution (a _StandardUniform with an Affine transformation), we can optionally use the merge_transforms method to combine the transformations explicitly, so that the base distribution is a _StandardUniform.
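A pure-JAX sketch of the point that Uniform is a transformed standard uniform (illustrative only, not flowjax's implementation): pushing the standard uniform on [0, 1] through x = LOC + SCALE * u with the assumed values LOC = -1 and SCALE = 2 gives exactly the Uniform(-1, 1) density, checked here against `jax.scipy.stats.uniform`. After merging, this affine map is simply the first bijection in the chain, applied before the flow layers and the inverse tanh.

```python
import jax.numpy as jnp
from jax.scipy.stats import uniform

# Affine parameters that map the standard uniform on [0, 1] onto [-1, 1]
LOC, SCALE = -1.0, 2.0


def std_uniform_log_prob(u):
    # Log density of the standard uniform: 0 on [0, 1], -inf elsewhere
    return jnp.where((u >= 0) & (u <= 1), 0.0, -jnp.inf)


def affine_uniform_log_prob(x):
    # Change of variables for x = LOC + SCALE * u:
    # log p(x) = log p_u((x - LOC) / SCALE) - log|SCALE|
    u = (x - LOC) / SCALE
    return std_uniform_log_prob(u) - jnp.log(jnp.abs(SCALE))


xs = jnp.array([-0.9, 0.0, 0.5, 0.99])
print(affine_uniform_log_prob(xs))
print(uniform.logpdf(xs, loc=LOC, scale=SCALE))  # same values: log(1/2) on [-1, 1]
```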

One extra thing: we need to mark the affine transformation as NonTrainable, because it must always map onto the [-1, 1] interval to match the spline.
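As a rough, library-agnostic sketch of what freezing means, assume it behaves like wrapping the affine parameters in `jax.lax.stop_gradient`, so the optimizer always sees a zero gradient for them (flowjax's NonTrainable may differ in detail). The toy loss and the `shift` parameter below are hypothetical stand-ins for the flow's trainable layers.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(params, x):
    # Frozen affine parameters: stop_gradient makes them constants for autodiff
    loc = lax.stop_gradient(params["loc"])
    scale = lax.stop_gradient(params["scale"])
    # Hypothetical trainable parameter standing in for the flow layers
    shift = params["shift"]
    u = (x - shift - loc) / scale
    return jnp.mean(u**2) + jnp.log(scale)


params = {"loc": jnp.array(-1.0), "scale": jnp.array(2.0), "shift": jnp.array(0.3)}
grads = jax.grad(loss)(params, jnp.array([0.1, -0.4, 0.7]))
print(grads)  # "loc" and "scale" gradients are exactly 0.0; "shift" is not
```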


import jax.random as jr
from flowjax.flows import masked_autoregressive_flow
from flowjax.bijections import RationalQuadraticSpline, Invert, Tanh
from flowjax.train import fit_to_data
import matplotlib.pyplot as plt
from flowjax.distributions import Uniform, Transformed, _StandardUniform
import jax.numpy as jnp
from flowjax.wrappers import NonTrainable
import equinox as eqx

nvars = 2
key, data_key, flow_key = jr.split(jr.PRNGKey(0), 3)  # separate keys for data and flow init
x = jr.normal(data_key, shape=(5000, nvars))
base_dist = Uniform(-jnp.ones(nvars), jnp.ones(nvars))

# Spline acts on [-1, 1], matching the support of the base distribution
transformer = RationalQuadraticSpline(knots=8, interval=1)

flow = masked_autoregressive_flow(
    key=flow_key,
    base_dist=base_dist,
    transformer=transformer,
)  # Support on [-1, 1]

flow = Transformed(
    flow, Invert(Tanh(flow.shape))
)  # Unbounded support

# Merge the Uniform's Affine into the bijection chain, leaving a _StandardUniform base
flow = flow.merge_transforms()

assert isinstance(flow.base_dist, _StandardUniform)

# Freeze the Affine (first bijection) so it always maps onto [-1, 1]
flow = eqx.tree_at(
    where=lambda flow: flow.bijection.bijections[0],
    pytree=flow,
    replace_fn=NonTrainable,
)

key, train_key, sample_key = jr.split(key, 3)

# Train 
flow, losses = fit_to_data(
    key=train_key,
    dist=flow,
    x=x,
    learning_rate=1e-3,
    max_patience=10,
    max_epochs=70,
)

plt.scatter(*flow.sample(sample_key, (100, )).T, label="flow")
plt.scatter(*x[:100].T, label="target")
plt.legend()
plt.show()

A couple of things:
1) Due to the issue fixed in https://github.com/danielward27/flowjax/pull/148, you may have to use main (or wait for the next release) for the above to work.
2) Things could be made a little simpler if RationalQuadraticSpline supported separate lower and upper interval bounds (rather than enforcing an interval symmetric around 0), so if I get time I will probably add this.

Hope this helps and let me know if you have any questions.

danielward27 commented 5 months ago

I'll close this for now, but as an aside I should note that the rational quadratic spline can occasionally be numerically unstable, leading to NaNs. I haven't dug into it extensively, but one of the issues is fixed in https://github.com/danielward27/flowjax/pull/152.

12kleingordon34 commented 4 months ago

Hi @danielward27, I was playing around with conditioning the MAF and put together a simple toy example based on your response to this issue, so I thought it best to ask here rather than open an entirely new thread.

I've generated data roughly from $X \sim \text{Bernoulli}(p=0.5),~Y\mid X \sim \mathcal{N}(5X, 1)$. I would expect the model to learn one Gaussian at $\mu=0$ when $X=0$ and another at $\mu=5$ when $X=1$ (a mixture of Gaussians marginally). However, the flow instead seems to learn the same distribution for both conditioning values. Is this possibly a bug in the code, or is it just the nature of the data and the normalizing flow I'm currently using?

I've pasted the code used below:

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import equinox as eqx

from flowjax.flows import masked_autoregressive_flow
from flowjax.bijections import RationalQuadraticSpline, Invert, Tanh
from flowjax.train import fit_to_data
from flowjax.distributions import Uniform, Transformed, _StandardUniform
from flowjax.wrappers import NonTrainable

# Enable 64-bit mode
jax.config.update('jax_enable_x64', True)

# Set up random keys
keys = jr.split(jr.PRNGKey(0), 6)

# Set up data
N = 5000
nvars = 1
x_cond = jnp.expand_dims(jnp.hstack([jnp.ones(int(N/2))*0, jnp.ones(int(N/2))*1]), axis=1)
y = jr.normal(keys[1], shape=(N, nvars)) + 5 * x_cond

# Set up flow model
base_dist = Uniform(-jnp.ones(nvars), jnp.ones(nvars))
transformer = RationalQuadraticSpline(knots=10, interval=1)
flow = masked_autoregressive_flow(
    key=keys[2],
    base_dist=base_dist,
    transformer=transformer,
    flow_layers=10,
    nn_width=50,
    nn_depth=5,
    cond_dim=x_cond.shape[1],
)
flow = Transformed(flow, Invert(Tanh(flow.shape)))  # Unbounded support
flow = flow.merge_transforms()
assert isinstance(flow.base_dist, _StandardUniform)
flow = eqx.tree_at(
    where=lambda flow: flow.bijection.bijections[0],
    pytree=flow,
    replace_fn=NonTrainable,
)

# Train the model
flow, losses = fit_to_data(
    key=keys[3],
    dist=flow,
    x=y,
    learning_rate=1e-3,
    max_patience=10,
    max_epochs=400,
    condition=x_cond,
)

# Sample from the model
flow_1 = flow.sample(keys[4], condition=jnp.ones((5000,1)))[:,0]
flow_0 = flow.sample(keys[5], condition=jnp.zeros((5000,1)))[:,0]

# Separate the data based on condition
y_1 = y[x_cond==1]
y_0 = y[x_cond==0]

# Plot the results
fig, ax = plt.subplots(2,2, sharex=True)
ax[0,0].hist(flow_1, label='flow for x=1')
ax[0,0].set_title('flow for x=1')
ax[0,1].hist(y_1, label='true for x=1')
ax[0,1].set_title('true for x=1')
ax[1,0].hist(flow_0, label='flow for x=0')
ax[1,0].set_title('flow for x=0')
ax[1,1].hist(y_0, label='true for x=0')
ax[1,1].set_title('true for x=0')
plt.show()

danielward27 commented 4 months ago

Hmm, I have just dug into this, and it is a bug (or at least a suboptimal implementation) in the conditioning/masking. In the univariate conditional case, the MAF layer should be roughly equivalent to a fully connected MLP that takes the conditioning variable as input and outputs the transformer parameters. Currently, however, all of those weights are masked out, so there is no dependence on the conditioning variable, which is exactly what you observed. I'll try to get this fixed for you today.
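A rough way to check whether a conditional flow is ignoring its condition is to compare its average negative log-likelihood with two reference values for this toy problem. The pure-JAX sketch below (not flowjax code; the Bernoulli sampling is an assumption that roughly mirrors the data in the report) computes the NLL of the true conditional N(5X, 1) and of the best condition-free model, which is the marginal mixture 0.5 N(0, 1) + 0.5 N(5, 1). Because the two components barely overlap, ignoring X costs roughly log 2 ≈ 0.69 nats per sample.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

key_x, key_y = jr.split(jr.PRNGKey(0))
n = 5000
x = jr.bernoulli(key_x, p=0.5, shape=(n,)).astype(jnp.float32)  # X ~ Bernoulli(0.5)
y = 5.0 * x + jr.normal(key_y, shape=(n,))                       # Y | X ~ N(5X, 1)

# True conditional log-density log p(y | x)
cond_log_prob = norm.logpdf(y, loc=5.0 * x, scale=1.0)

# Best model that ignores x: the marginal mixture 0.5 N(0, 1) + 0.5 N(5, 1)
marg_log_prob = logsumexp(
    jnp.stack([norm.logpdf(y, 0.0, 1.0), norm.logpdf(y, 5.0, 1.0)]) + jnp.log(0.5),
    axis=0,
)

cond_nll = float(-cond_log_prob.mean())
marg_nll = float(-marg_log_prob.mean())
print("mean NLL, conditional:", cond_nll)
print("mean NLL, marginal:   ", marg_nll)
```

If a trained flow's loss levels off near the marginal value rather than the conditional one, that is consistent with the condition having no effect.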

Briefly, masking is controlled by assigning ranks to the nodes of each layer (see figure 1 of https://arxiv.org/pdf/1502.03509). By convention, the conditioning variables are given input rank -1 so that they are not masked from any of the other nodes. However, the hidden layer ranks should also include -1, and currently they don't, so the dependence on the conditioning variables is lost for the variable with rank 0. For multidimensional examples with permutations this likely doesn't make much difference, but in the univariate case the only variable has rank 0, so the problem becomes clear.
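To make the rank argument concrete, here is a small pure-JAX sketch of MADE-style masks as in the linked paper. The conventions are assumptions for illustration, not flowjax's code: ranks are 0-indexed, the conditioning input has rank -1, a hidden unit may see inputs whose rank is less than or equal to its own, and an output may only see hidden units of strictly lower rank. The hidden-layer ranks are hypothetical.

```python
import jax.numpy as jnp


def made_mask(in_ranks, out_ranks, strict):
    # Entry [j, k] is 1 if unit j of the next layer may connect to unit k of this layer
    in_ranks = jnp.asarray(in_ranks)[None, :]
    out_ranks = jnp.asarray(out_ranks)[:, None]
    allowed = out_ranks > in_ranks if strict else out_ranks >= in_ranks
    return allowed.astype(jnp.int32)


def path_counts(in_ranks, hidden_ranks, out_ranks):
    # Number of unmasked paths from each input to each output through one hidden layer
    m_in = made_mask(in_ranks, hidden_ranks, strict=False)   # input -> hidden
    m_out = made_mask(hidden_ranks, out_ranks, strict=True)  # hidden -> output
    return m_out @ m_in


# Univariate target y (rank 0) plus one conditioning input (rank -1)
in_ranks = [0, -1]   # columns: [y, condition]
out_ranks = [0]      # transformer parameters for y

# Hidden ranks without -1: nothing reaches the output at all
buggy = path_counts(in_ranks, hidden_ranks=[0, 0, 0, 0], out_ranks=out_ranks)
# Hidden ranks including -1: the condition reaches the output, y still does not
fixed = path_counts(in_ranks, hidden_ranks=[-1, -1, 0, 0], out_ranks=out_ranks)

print(buggy)  # [[0 0]]
print(fixed)  # [[0 2]]
```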

Thanks so much for the report and apologies for the inconvenience!

12kleingordon34 commented 4 months ago

Thanks so much for responding and getting a fix for this so promptly! Loving the package so far -- super clean, elegant and straightforward to use!

danielward27 commented 4 months ago

This should be fixed now by https://github.com/danielward27/flowjax/pull/162, let me know if you have any other issues!

danielward27 commented 2 months ago

Just to mention: on the latest version (12.4.0), it should be possible to avoid the need for eqx.tree_at here by instead using:

import jax.numpy as jnp
from flowjax.distributions import Uniform
from flowjax.wrappers import non_trainable
base_dist = non_trainable(Uniform(-jnp.ones(nvars), jnp.ones(nvars)))  # affine stays fixed to [-1, 1]