danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Transformation for conditioning variables #98

Closed: vallis closed this issue 1 year ago

vallis commented 1 year ago

Hello Daniel, excellent package!

I would like to do conditional estimation, but apply a learnable transformation to the conditioning variables (the “u” in your example) before they are fed to the flow, and hopefully optimize the transformation as part of the fit.

Can I do it in flowjax without too much surgery?

danielward27 commented 1 year ago

Hello, thanks!

This should be possible using the EmbedCondition bijection. I think something like this should work:

import jax.random as jr
import jax.numpy as jnp
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.distributions import Normal, Transformed
from flowjax.bijections import EmbedCondition
from equinox.nn import MLP

# Create data, from which we want to infer p(x|u)
key, x_key, cond_key = jr.split(jr.PRNGKey(0), 3)
u = jr.uniform(cond_key, (10000, 2), minval=0, maxval=5)
x = jr.uniform(x_key, shape=u.shape, maxval=u)
key, subkey = jr.split(key)  # split the existing key rather than reusing PRNGKey(0)

raw_dim = u.shape[1]
embedded_dim = 1  # Here we embed from 2d to 1d

# Define flow with cond_dim==embedded_dim. Just take bijection component.
bnaf_bijection = block_neural_autoregressive_flow(
    key=subkey,
    base_dist=Normal(jnp.zeros(x.shape[1])),
    cond_dim=embedded_dim,
).bijection

# Define a new bijection that includes the embedding network (here an MLP)
key, subkey = jr.split(key)

embedding_net = MLP(
    in_size=raw_dim, out_size=embedded_dim, width_size=5, depth=1, key=subkey
)  # callable mapping raw_condition -> embedded_condition

bnaf_bijection_with_embedding_net = EmbedCondition(
    bijection=bnaf_bijection,
    embedding_net=embedding_net,
    raw_cond_shape=(raw_dim,),
)

# Create flow with new bijection
flow = Transformed(Normal(jnp.zeros(x.shape[1])), bnaf_bijection_with_embedding_net)

The MLP parameters will be trained by default: they are inexact (floating-point) arrays, so equinox's filtering treats them as trainable along with the rest of the flow's parameters. Let me know if you have any thoughts/questions.