Closed vallis closed 1 year ago
Hello, thanks!
This should be possible using the EmbedCondition bijection. I think something like this should work:
import jax.random as jr
import jax.numpy as jnp
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import EmbedCondition
from equinox.nn import MLP
# Create data from which we want to infer p(x|u)
key, x_key, cond_key = jr.split(jr.PRNGKey(0), 3)
u = jr.uniform(cond_key, (10000, 2), minval=0, maxval=5)
x = jr.uniform(x_key, shape=u.shape, maxval=u)
key, subkey = jr.split(key)  # reuse the key from above rather than reseeding
raw_dim = u.shape[1]
embedded_dim = 1 # Here we embed from 2d to 1d
# Define flow with cond_dim==embedded_dim. Just take bijection component.
bnaf_bijection = block_neural_autoregressive_flow(
    key=subkey,
    base_dist=Normal(jnp.zeros(x.shape[1])),
    cond_dim=embedded_dim,
).bijection
# Define a new bijection that includes the embedding network (here an MLP)
key, subkey = jr.split(key)
embedding_net = MLP(
    in_size=raw_dim, out_size=embedded_dim, width_size=5, depth=1, key=subkey
)  # callable mapping raw_condition -> embedded_condition
bnaf_bijection_with_embedding_net = EmbedCondition(
    bijection=bnaf_bijection,
    embedding_net=embedding_net,
    raw_cond_shape=(raw_dim,),
)
# Create flow with new bijection
flow = Transformed(Normal(jnp.zeros(x.shape[1])), bnaf_bijection_with_embedding_net)
The MLP parameters will be trained by default, as they are inexact arrays and so will be picked up by equinox's filtering. Let me know if you have any thoughts/questions.
Hello Daniel, excellent package!
I would like to do conditional density estimation, but apply a learnable transformation to the conditioning variables (the "u" in your examples) before they are fed to the flow, and ideally optimize that transformation as part of the fit.
Can I do it in flowjax without too much surgery?