Closed llaurabatt closed 5 months ago
I think you are close with the second approach, however, if your base distribution is defined on an interval, it seems a shame to not utilise the spline on that same interval. Here's the approach that came to mind for me: 1) Define the base distribution and the transformer on [-1,1]. This will make a flow with support on [-1, 1]. 2) Transform after the flow layers to unbounded with inverse tanh.
As `Uniform` is itself a transformed distribution (a `_StandardUniform` with an `Affine` transformation), if we want we can use the `merge_transforms` method to explicitly combine the transformations, such that the base distribution is `_StandardUniform`.
One extra thing: we need to mark the affine transformation as `NonTrainable`, as we want it to always map to the [-1, 1] interval in order to match the spline.
import jax.random as jr
from flowjax.flows import masked_autoregressive_flow
from flowjax.bijections import RationalQuadraticSpline, Invert, Tanh
from flowjax.train import fit_to_data
import matplotlib.pyplot as plt
from flowjax.distributions import Uniform, Transformed, _StandardUniform
import jax.numpy as jnp
from flowjax.wrappers import NonTrainable
import equinox as eqx
nvars = 2
key, subkey = jr.split(jr.PRNGKey(0))
x = jr.normal(subkey, shape=(5000, nvars))
base_dist = Uniform(-jnp.ones(nvars), jnp.ones(nvars))
transformer = RationalQuadraticSpline(knots=8, interval=1)
flow = masked_autoregressive_flow(
    key=subkey,
    base_dist=base_dist,
    transformer=transformer,
)  # Support on [-1, 1]
flow = Transformed(flow, Invert(Tanh(flow.shape)))  # Unbounded support
flow = flow.merge_transforms()
assert isinstance(flow.base_dist, _StandardUniform)
# Freeze the affine map so the base always matches the spline interval [-1, 1]
flow = eqx.tree_at(
    where=lambda flow: flow.bijection.bijections[0],
    pytree=flow,
    replace_fn=NonTrainable,
)
key, subkey = jr.split(key)
# Train
flow, losses = fit_to_data(
    key=subkey,
    dist=flow,
    x=x,
    learning_rate=1e-3,
    max_patience=10,
    max_epochs=70,
)
plt.scatter(*flow.sample(key, (100, )).T, label="flow")
plt.scatter(*x[:100].T, label="target")
plt.legend()
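To see in isolation why step 2 gives unbounded support, here is a minimal sketch that uses plain NumPy rather than flowjax. It takes the simplest case, u ~ Uniform(-1, 1) with no flow layers, and applies the change of variables for z = arctanh(u). The function name `log_prob_unbounded` is made up for illustration and is not part of any library.

```python
import numpy as np


def log_prob_unbounded(z):
    """Log density of z = arctanh(u) when u ~ Uniform(-1, 1).

    Change of variables: log p_z(z) = log p_u(tanh(z)) + log|d tanh(z)/dz|,
    where d tanh(z)/dz = 1 - tanh(z)**2.
    """
    log_base = np.log(0.5)  # uniform density on (-1, 1)
    log_abs_det = np.log1p(-np.tanh(z) ** 2)
    return log_base + log_abs_det


# The density is finite at every real z and integrates to (approximately) 1.
z = np.linspace(-10.0, 10.0, 20001)
density = np.exp(log_prob_unbounded(z))
total_mass = float(np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(z)))
```

Because the log density is finite everywhere on the real line, data such as standard normal samples can no longer fall outside the support of the model.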
A couple things:
1) Due to the issue fixed in https://github.com/danielward27/flowjax/pull/148, you may have to use main (or wait for the next release) for the above to work.
2) Things could be made a little simpler if `RationalQuadraticSpline` supported passing an upper and lower interval (rather than enforcing it to be symmetric around 0), so if I get time I will probably add this.
Hope this helps and let me know if you have any questions.
I'll close this for now, but as an aside I should note there can occasionally be numerical instability with the rational quadratic spline, leading to NaNs. I haven't dug into it extensively, but one of the issues is fixed in https://github.com/danielward27/flowjax/pull/152.
Hi @danielward27, I was playing around with conditioning the MAF and cooked up a simple toy example based on your response to this issue and thought it was best to ask a Q here rather than make an entirely new thread.
I've generated data roughly from $X \sim \text{Bernoulli}(p=0.5),~Y\mid X \sim \mathcal{N}(5X, 1)$. I would expect the model to learn a mixture of Gaussians, learning one Gaussian at $\mu=0$ if $X=0$, and another at $\mu=5$ if $X=1$. However, the flow instead seems to learn the same distribution for both conditioning cases. Is this a possible error in the code or is it just the nature of the data and the normalizing flow I'm currently using?
I've pasted the code used below:
import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import equinox as eqx
from flowjax.flows import masked_autoregressive_flow
from flowjax.bijections import RationalQuadraticSpline, Invert, Tanh
from flowjax.train import fit_to_data
from flowjax.distributions import Uniform, Transformed, _StandardUniform
from flowjax.wrappers import NonTrainable
# Enable 64-bit mode
jax.config.update('jax_enable_x64', True)
# Set up random keys
keys = jr.split(jr.PRNGKey(0), 6)
# Set up data
N = 5000
nvars = 1
x_cond = jnp.expand_dims(jnp.hstack([jnp.ones(int(N/2))*0, jnp.ones(int(N/2))*1]), axis=1)
y = jr.normal(keys[1], shape=(N, nvars)) + 5 * x_cond
# Set up flow model
base_dist = Uniform(-jnp.ones(nvars), jnp.ones(nvars))
transformer = RationalQuadraticSpline(knots=10, interval=1)
flow = masked_autoregressive_flow(
    key=keys[2],
    base_dist=base_dist,
    transformer=transformer,
    flow_layers=10,
    nn_width=50,
    nn_depth=5,
    cond_dim=x_cond.shape[1],
)
flow = Transformed(flow, Invert(Tanh(flow.shape)))  # Unbounded support
flow = flow.merge_transforms()
assert isinstance(flow.base_dist, _StandardUniform)
flow = eqx.tree_at(
    where=lambda flow: flow.bijection.bijections[0],
    pytree=flow,
    replace_fn=NonTrainable,
)
# Train the model
flow, losses = fit_to_data(
    key=keys[3],
    dist=flow,
    x=y,
    learning_rate=1e-3,
    max_patience=10,
    max_epochs=400,
    condition=x_cond,
)
# Sample from the model
flow_1 = flow.sample(keys[4], condition=jnp.ones((5000,1)))[:,0]
flow_0 = flow.sample(keys[5], condition=jnp.zeros((5000,1)))[:,0]
# Separate the data based on condition
y_1 = y[x_cond==1]
y_0 = y[x_cond==0]
# Plot the results
fig, ax = plt.subplots(2,2, sharex=True)
ax[0,0].hist(flow_1, label='flow for x=1')
ax[0,0].set_title('flow for x=1')
ax[0,1].hist(y_1, label='true for x=1')
ax[0,1].set_title('true for x=1')
ax[1,0].hist(flow_0, label='flow for x=0')
ax[1,0].set_title('flow for x=0')
ax[1,1].hist(y_0, label='true for x=0')
ax[1,1].set_title('true for x=0')
plt.show()
Hmm I have just dug into this and it is a bug/suboptimal implementation for conditioning/masking. In the univariate conditional case, the MAF layer should be roughly equivalent to a model that uses a fully connected MLP taking in the conditioning variable and outputting the transformer parameters, but currently all the weights are masked out, meaning you got the results you observed (no dependency on the conditioning variable). I'll try to have this fixed today for you.
Briefly, masking is controlled by assigning ranks to the nodes of each layer (see figure 1 here https://arxiv.org/pdf/1502.03509). By convention, the input ranks for the conditioning variables are set to -1 so that they are not masked from any of the other nodes. However, the hidden layer ranks should also include -1, but they don't, meaning the conditioning dependency is lost for the input with rank 0. For multidimensional examples with permutations, this likely doesn't make much difference, but for the univariate case, the problem becomes clear.
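The rank argument above can be illustrated with a small NumPy sketch. This is not flowjax's actual implementation: the helpers `rank_mask` and `reaches` are hypothetical, and the masking rule (a hidden node may see inputs of rank less than or equal to its own, an output node only nodes of strictly lower rank) is an assumption taken from the MADE paper linked above.

```python
import numpy as np


def rank_mask(in_ranks, out_ranks, strict):
    """Boolean mask where mask[i, j] is True if out node i may use in node j."""
    in_ranks = np.asarray(in_ranks)
    out_ranks = np.asarray(out_ranks)
    if strict:
        return out_ranks[:, None] > in_ranks[None, :]
    return out_ranks[:, None] >= in_ranks[None, :]


def reaches(input_ranks, hidden_ranks, output_ranks):
    """Which inputs can influence which outputs through one hidden layer."""
    m1 = rank_mask(input_ranks, hidden_ranks, strict=False)  # input -> hidden
    m2 = rank_mask(hidden_ranks, output_ranks, strict=True)  # hidden -> output
    return (m2.astype(int) @ m1.astype(int)) > 0


# Univariate conditional case: inputs are [y, condition] with ranks [0, -1].
input_ranks = [0, -1]
output_ranks = [0]

# Hidden ranks without -1: every hidden -> output weight is masked out.
before = reaches(input_ranks, [0, 0, 0, 0], output_ranks)

# Hidden ranks that include -1: the condition now reaches the output.
after = reaches(input_ranks, [-1, 0, -1, 0], output_ranks)
```

Under these assumptions, `before` has no connected paths at all, while `after` connects the output to the conditioning variable only, which is the behaviour an autoregressive transformer for a single variable needs.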
Thanks so much for the report and apologies for the inconvenience!
Thanks so much for responding and getting a fix for this so promptly! Loving the package so far -- super clean, elegant and straightforward to use!
This should be fixed now by https://github.com/danielward27/flowjax/pull/162, let me know if you have any other issues!
Just to mention, on the latest version (12.4.0) it should be possible to avoid the need for `eqx.tree_at` here, by instead using
from flowjax.wrappers import non_trainable
base_dist = non_trainable(Uniform(-jnp.ones(nvars), jnp.ones(nvars)))
I am new to your package and am playing with your Bounded Flow example. I am using MAF with a rational quadratic spline transformer and want to use a standard uniform as a base distribution. If I simulate $x$ from a standard gaussian and train the flow with MLE, I immediately get an infinite loss and samples from the flow do not correctly recover the target.
This happens as well if I use a Uniform with a larger support
I have as well tried to build an "unbounded uniform" class for my base distribution where I pass uniform samples through an inverse tanh the same way you do in the example
but I get the same behavior. I would like to know whether this is expected and I am doing something wrong/silly or there is any workaround to this. Thank you for the help and all the good work!