danielward27 / flowjax


masked autoregressive flow with mixed transformer types #161

Open llaurabatt opened 3 months ago

llaurabatt commented 3 months ago

I am looking into a modification of a regular masked autoregressive flow in which the base distribution is an N-dimensional uniform and the first variable is not transformed, while the remaining variables are transformed via a rational quadratic spline. I have removed the shuffling in the masked_autoregressive_flow function by removing the _add_default_permute call, and I have modified _flat_params_to_transformer in the MaskedAutoregressive class to apply an Identity transformer to the first dimension, as follows:

    def _flat_params_to_transformer(self, params: Array, y_dim=1):
        """Reshape to dim X params_per_dim, then vmap."""
        dim = self.shape[-1]
        transformer_params = jnp.reshape(params, (dim, -1))
        # Discard the parameters predicted for the first y_dim variables,
        # which are left untransformed.
        transformer_params = transformer_params[y_dim:, :]
        transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
        # Apply Identity to the first y_dim variables and the vmapped
        # transformer to the remaining dim - y_dim variables.
        return Concatenate(
            [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
        )

My understanding is that the masked_autoregressive_mlp will still produce a set of spline parameters for the first variable, which are then never used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure this is the most efficient way to go about it, or whether I am disregarding anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!

danielward27 commented 3 months ago

I think your approach works, but it carries a bit of extra overhead because, as you said, the masked autoregressive network still produces a set of (unused) parameters for the identity-transformed variables. If you want to avoid that, here's another possibility.

What I have done is wrap a masked autoregressive bijection whose dimension matches the number of transformed variables and whose cond_dim matches the number of identity-transformed variables. In each method we pass the identity-transformed variables into the masked autoregressive bijection as conditioning variables. This should be an equivalent architecture, while avoiding the unnecessary computation.

from typing import ClassVar
from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.masked_autoregressive import MaskedAutoregressive

class IdentityFirstMaskedAutoregressive(AbstractBijection):
    """Leave the first identity_dim variables unchanged, transforming the
    rest with a masked autoregressive bijection conditioned on them."""

    masked_autoregressive: MaskedAutoregressive
    identity_dim: int
    shape: tuple[int, ...]
    cond_shape: ClassVar[None] = None

    def __init__(self, masked_autoregressive: MaskedAutoregressive):
        self.masked_autoregressive = masked_autoregressive
        # The wrapped bijection's cond_dim is the number of identity variables.
        self.identity_dim = masked_autoregressive.cond_shape[0]
        self.shape = (self.identity_dim + self.masked_autoregressive.shape[0],)

    def transform(self, x, condition=None):
        # Transform the trailing variables, conditioning on the leading ones.
        y = self.masked_autoregressive.transform(
            x[self.identity_dim :], condition=x[: self.identity_dim]
        )
        return x.at[self.identity_dim :].set(y)

    def transform_and_log_det(self, x, condition=None):
        # The identity part contributes nothing to the log determinant.
        y, log_det = self.masked_autoregressive.transform_and_log_det(
            x[self.identity_dim :],
            condition=x[: self.identity_dim],
        )
        return x.at[self.identity_dim :].set(y), log_det

    def inverse(self, y, condition=None):
        # The leading variables are untouched, so they can condition the inverse.
        x = self.masked_autoregressive.inverse(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x)

    def inverse_and_log_det(self, y, condition=None):
        x, log_det = self.masked_autoregressive.inverse_and_log_det(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x), log_det
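For example, it could be constructed like this (my sketch, not from the original comment; the hyperparameters are arbitrary):

import jax
from flowjax.bijections import RationalQuadraticSpline

identity_dim = 1
dim = 5  # total dimensionality, including the identity variable

# The wrapped bijection covers the last dim - identity_dim variables and
# conditions on the first identity_dim variables via cond_dim.
inner = MaskedAutoregressive(
    key=jax.random.PRNGKey(0),
    transformer=RationalQuadraticSpline(knots=5, interval=1.0),
    dim=dim - identity_dim,
    cond_dim=identity_dim,
    nn_width=10,
    nn_depth=1,
)
bijection = IdentityFirstMaskedAutoregressive(inner)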

If you need to support a conditional version of this, it should be possible with some concatenating and adjusting of shapes, e.g. as in the sketch below.
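A minimal sketch of one way to do that (my assumption, not from the original comment): construct the wrapped MaskedAutoregressive with cond_dim equal to identity_dim plus the external condition size, make cond_shape a field on the wrapper rather than a ClassVar, and concatenate the identity variables with the external condition in each method:

import jax.numpy as jnp

# Sketch only: a conditional transform, assuming the wrapped
# MaskedAutoregressive was built with
# cond_dim = identity_dim + external_cond_dim.
def transform(self, x, condition=None):
    # Identity variables come first, followed by the external condition.
    cond = jnp.concatenate([x[: self.identity_dim], condition])
    y = self.masked_autoregressive.transform(
        x[self.identity_dim :], condition=cond
    )
    return x.at[self.identity_dim :].set(y)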

In general it could be possible to add support for a mix of transformer types, but if we assume a list of heterogeneous transformers, for example, compilation speed might become an issue, as we could no longer rely on a single vmap and would have to loop over the transformers (see the sketch below). Thanks for the support and let me know if you have any questions/issues!
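To illustrate that trade-off (my illustration, not from the thread): flowjax's Stack already combines heterogeneous scalar bijections, but it has to handle each bijection separately, so the transform effectively unrolls dimension by dimension when traced, rather than being a single vmapped call:

import jax.numpy as jnp
from flowjax.bijections import Affine, RationalQuadraticSpline, Stack

# Two different scalar bijections stacked into a shape (2,) bijection.
# Each is applied individually, which unrolls at trace time.
hetero = Stack([Affine(), RationalQuadraticSpline(knots=5, interval=1.0)])
y = hetero.transform(jnp.zeros(2))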

mdmould commented 1 month ago

This is a bit late, but another option is to define individual bijections for the non-transformed variables and the remaining ones, and then stack them together into a single bijection:

import jax
import jax.numpy as jnp
from flowjax.bijections import Identity, RationalQuadraticSpline, MaskedAutoregressive, Concatenate
from flowjax.distributions import Uniform, Transformed

N = 5

base_dist = Uniform(minval=-jnp.ones(N), maxval=jnp.ones(N))

bijections = [
    Identity(shape=(1,)),
    MaskedAutoregressive(
        key=jax.random.PRNGKey(0),
        transformer=RationalQuadraticSpline(knots=5, interval=1.0),
        dim=N - 1,
        nn_width=10,
        nn_depth=1,
    ),
]

# Use Concatenate, as it joins bijections along an *existing* axis.
bijection = Concatenate(bijections)

flow = Transformed(base_dist, bijection)
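The resulting flow then behaves like any other flowjax distribution, e.g. (my addition):

key = jax.random.PRNGKey(1)
samples = flow.sample(key, (1000,))  # the first column stays uniform
log_probs = flow.log_prob(samples)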

You could wrap this in a constructor that chains several such layers, with vmap and with permutations only within the last N - 1 dimensions (a sketch follows below), though the Concatenate may add some overhead.
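A rough sketch of such a constructor (my assumption, not from the comment), permuting only the last N - 1 indices between layers so the first variable is never moved or transformed:

import jax
import jax.numpy as jnp
from flowjax.bijections import (
    Chain,
    Concatenate,
    Identity,
    MaskedAutoregressive,
    Permute,
    RationalQuadraticSpline,
)

def make_flow_bijection(key, dim, n_layers=4):
    layers = []
    for layer_key in jax.random.split(key, n_layers):
        maf_key, perm_key = jax.random.split(layer_key)
        layers.append(
            Concatenate([
                Identity(shape=(1,)),
                MaskedAutoregressive(
                    key=maf_key,
                    transformer=RationalQuadraticSpline(knots=5, interval=1.0),
                    dim=dim - 1,
                    nn_width=10,
                    nn_depth=1,
                ),
            ])
        )
        # Permute indices 1..dim-1 only; index 0 stays in place.
        permutation = jnp.concatenate(
            [jnp.array([0]), 1 + jax.random.permutation(perm_key, dim - 1)]
        )
        layers.append(Permute(permutation))
    return Chain(layers)

bijection = make_flow_bijection(jax.random.PRNGKey(0), dim=N)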

danielward27 commented 1 month ago

You can do that, but note that the transform of the transformed dimensions will then be independent of the identity-transformed variables.

danielward27 commented 1 month ago

It could be possible to support a transformer with shape/dimension matching the shape of the overall bijection (rather than only scalar transformers), in which case you could Stack/Concatenate transformers as you please. The main issue is that you need a reliable way to know/specify which parameters are involved in transforming which dimensions. The only way I can think of would be to force passing a pytree of ints with structure matching the parameters, specifying which output dimension the parameters pertain to (quite cumbersome).
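Purely as a hypothetical illustration of that interface (nothing like this exists in flowjax): the pytree of ints would mirror the transformer's parameter structure, with each leaf recording which output dimension its parameters transform:

import jax.numpy as jnp

# Hypothetical only: parameters for a 2-dimensional bijection, paired with
# a pytree of ints of matching structure mapping each parameter leaf to
# the output dimension it parameterizes.
params = {
    "affine": {"loc": jnp.array(0.0), "scale": jnp.array(1.0)},
    "spline": jnp.zeros(16),  # placeholder spline parameters
}
param_to_dim = {
    "affine": {"loc": 0, "scale": 0},  # affine transforms dimension 0
    "spline": 1,  # spline transforms dimension 1
}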