google-deepmind / distrax

Distrax + Flax bijector error and best practices #263

Open JamesAllingham opened 10 months ago

JamesAllingham commented 10 months ago

I've encountered a small error when implementing Distrax bijectors with Flax conditioner NNs, and I also have a question about best practices for using Distrax with Flax.

The error can be reproduced with the following setup (also in this Colab notebook https://colab.research.google.com/drive/1RLRZul_pHnglcT_-YZ7mcuKLU1qd3w5O?usp=sharing).

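import functools
from typing import Any, Mapping, Optional, Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# Type aliases assumed here; the notebook may define them differently.
Array = jax.Array
KwArgs = Mapping[str, Any]

class Conditioner(nn.Module):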
    event_shape: Sequence[int]
    num_bijector_params: int
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, z: Array, h: Array) -> Array:
        h = jnp.concatenate((z.flatten(), h.flatten()), axis=0)

        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        y = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(h)
        y = y.reshape(tuple(self.event_shape) + (self.num_bijector_params,))

        return y

class MyModel(nn.Module):
    hidden_dims: Sequence[int]
    num_flows: int
    num_bins: int
    event_shape: Sequence[int]
    conditioner: Optional[KwArgs] = None

    @nn.compact
    def __call__(self, x, y: Optional[Array] = None):
        # base distribution
        output_dim = np.prod(self.event_shape)
        base = distrax.Independent(
            distrax.Normal(loc=jnp.zeros(output_dim,), scale=jnp.ones(output_dim,)), len(self.event_shape)
        )

        # bijector
        # Number of parameters for the rational-quadratic spline:
        # - `num_bins` bin widths
        # - `num_bins` bin heights
        # - `num_bins + 1` knot slopes
        # for a total of `3 * num_bins + 1` parameters.
        num_bijector_params = 3 * self.num_bins + 1

        layers = []
        mask = jnp.arange(0, np.prod(self.event_shape)) % 2
        mask = jnp.reshape(mask, self.event_shape)
        mask = mask.astype(bool)

        def bijector_fn(params: Array):
            return distrax.RationalQuadraticSpline(
                params, range_min=-3.0, range_max=3.0
            )

        h = x.flatten()

        # shared feature extractor
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        for i in range(self.num_flows):
            conditioner = Conditioner(
                event_shape=self.event_shape,
                num_bijector_params=num_bijector_params,
                **(self.conditioner or {}),
            )

            layer = distrax.MaskedCoupling(
                mask=mask,
                bijector=bijector_fn,
                conditioner=functools.partial(conditioner, h=h),
            )

            layers.append(layer)
            mask = ~mask

        bijector = distrax.Inverse(distrax.Chain(layers))
        transformed = distrax.Transformed(base, bijector)

        if y is not None:
            return transformed, transformed.log_prob(y)
        else:
            return transformed

model = MyModel(
    hidden_dims = [64, 32],
    num_flows = 3,
    num_bins = 8,
    event_shape = (6,),
    conditioner = {'hidden_dims': [64, 32]}
)

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))

dist = model.apply(variables, jnp.ones((28, 28, 1)))

dist.event_shape

Accessing dist.event_shape on the last line raises the following error:

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

Thankfully, evaluating log probs, i.e., dist.log_prob(jnp.zeros(6,)), runs without any error.

Any idea why this is happening? Am I doing something wrong when constructing the model?
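
For reference, the sizes involved in the configuration above can be checked by hand. The sketch below is plain Python and does not touch Distrax or Flax; the helper names spline_param_count and alternating_masks are illustrative only. It mirrors the arithmetic in the model's comments (3 * num_bins + 1 spline parameters per event dimension) and the alternating mask that is flipped after each coupling layer.

```python
import math


def spline_param_count(num_bins: int) -> int:
    """Bin widths + bin heights + knot slopes, as in the model's comments."""
    widths = num_bins
    heights = num_bins
    slopes = num_bins + 1
    return widths + heights + slopes


def alternating_masks(event_size: int, num_flows: int) -> list:
    """One boolean mask per coupling layer, flipped after every layer."""
    mask = [bool(i % 2) for i in range(event_size)]
    masks = []
    for _ in range(num_flows):
        masks.append(mask)
        mask = [not m for m in mask]
    return masks


# Values taken from the MyModel(...) call in the reproduction.
event_shape = (6,)
num_bins = 8
num_flows = 3

num_params = spline_param_count(num_bins)
# Shape the conditioner output is reshaped to: event_shape + (num_params,)
conditioner_out_shape = tuple(event_shape) + (num_params,)
# Width of the conditioner's final Dense layer before the reshape.
dense_features = math.prod(event_shape) * num_params
masks = alternating_masks(math.prod(event_shape), num_flows)

print(num_params, conditioner_out_shape, dense_features)
for m in masks:
    print(m)
```

With these settings each conditioner emits 25 parameters for each of the 6 event dimensions, and the first and third coupling layers share the same mask.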

On that note, I've also found that if I initialize the parameters like this:

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)))

the parameters for the conditioner are not instantiated. To fix this, I've used the workaround of evaluating the log prob of some dummy data when initializing the model:

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))

But this feels a little hacky to me and suggests that perhaps I am doing something wrong in my model definition. Do you have a set of best practices for using Flax with Distrax (now that Haiku is deprecated)?

Thanks for the help!