dirmeier / sbijax

Simulation-based inference in JAX
https://sbijax.rtfd.io
Apache License 2.0

Can't successfully run the slcp-snle example #42

Open Jice-Zeng opened 2 months ago

Jice-Zeng commented 2 months ago

Hi Simon, I am trying to run the SLCP example with surjective layers, specifically the file slcp-snle:

y_obs = jnp.array([[
    -0.9707123, -2.9461224, -0.4494722, -3.4231849,
    -0.13285634, -3.364017, -0.85367596, -2.4271638,
]])
fns = prior_fn, simulator_fn

neural_network = make_maf(8, n_layer_dimensions=[8, 8, 5, 5, 5])
snl = NLE(fns, neural_network)
optimizer = optax.adam(1e-4)
data, _ = snl.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
params_t, losses = snl.fit(
    jr.PRNGKey(0),
    data=data,
    n_early_stopping_patience=100,
    optimizer=optimizer,
    n_iter=1000,
)

Unfortunately, I got an error:

ValueError: Incompatible shapes for broadcasting: shapes=[(100, 5), (8,)]

If I make all layers bijective, e.g. n_layer_dimensions=[8, 8, 8, 8, 8], the implementation works well, so I guess the issue results from the use of the surjective layers. I tried modifying the code many times, but the error persists. Can you give me some ideas? Thanks!

I also tried another neural network that uses MaskedCouplingInferenceFunnel, see below:

def _make_maf(
    n_dimension,
    n_layer_dimensions,
    hidden_sizes,
    activation,
):
    def _bijector_fn(params):
        means, log_scales = unstack(params, -1)
        return distrax.ScalarAffine(means, jnp.exp(log_scales))

    def _decoder_fn(dims):
        def fn(z):
            params = surjectors_mlp(dims, activation=activation)(z)
            mu, log_scale = jnp.split(params, 2, -1)
            return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
        return fn

    def _conditioner(n_dim):
        return hk.Sequential(
            [
                surjectors_mlp(
                    list(hidden_sizes) + [2 * n_dim],
                    activation=activation,
                ),
                hk.Reshape((n_dim, 2)),
            ]
        )

    @hk.transform
    def _flow(method, **kwargs):
        layers = []
        order = jnp.arange(n_dimension)
        curr_dim = n_dimension
        for i, n_dim_curr_layer in enumerate(n_layer_dimensions):
            # layer is dimensionality preserving
            if n_dim_curr_layer == curr_dim:
                layer = MaskedCoupling(
                    mask=make_alternating_binary_mask(curr_dim, i % 2 == 0),
                    conditioner=_conditioner(curr_dim),
                    bijector_fn=_bijector_fn,
                )

                order = order[::-1]

            # layer reduces the dimensionality (surjective funnel layer)
            elif n_dim_curr_layer < curr_dim:
                n_latent = n_dim_curr_layer
                layer = MaskedCouplingInferenceFunnel(
                    n_keep=n_latent,
                    decoder=_decoder_fn(
                        list(hidden_sizes) + [2 * (curr_dim - n_latent)]
                    ),
                    conditioner=surjectors_mlp(
                        list(hidden_sizes) + [2 * curr_dim],
                        activation=activation,
                    ),
                    bijector_fn=_bijector_fn,
                )
                curr_dim = n_latent

                order = order[::-1]
                order = order[:curr_dim] - jnp.min(order[:curr_dim])
            else:
                raise ValueError(
                    f"n_dimension at layer {i} is larger than the dimension of"
                    f" the following layer {i + 1}"
                )
            layers.append(layer)
            layers.append(Permutation(order, 1))
        chain = Chain(layers[:-1])

        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
            1,
        )
        td = TransformedDistribution(base_distribution, chain)
        return td(method, **kwargs)

    return _flow
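
I plug this network into NLE in roughly the same way as in the make_maf snippet above; here is a sketch of my setup (the hidden_sizes and activation values below are just placeholders):

neural_network = _make_maf(
    8,
    n_layer_dimensions=[8, 8, 5, 5, 5],
    hidden_sizes=[64, 64],
    activation=jax.nn.relu,
)
snl = NLE((prior_fn, simulator_fn), neural_network)
data, _ = snl.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
params_t, losses = snl.fit(jr.PRNGKey(0), data=data, optimizer=optax.adam(1e-4))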

Running this, I got the error ValueError: too many values to unpack (expected 2). Thanks for your contribution to the library; I am looking forward to your reply.

dirmeier commented 2 months ago

Hi,

Yes, you are right. When refactoring the package for the submission, I introduced an error in make_maf. I will push a fix later. This should be

    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(curr_dim), jnp.ones(curr_dim)),
        1,
    )

and not n_dimension.
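
To make the shape error concrete: with n_layer_dimensions=[8, 8, 5, 5, 5], the funnel layers compress the latent space from 8 to 5 dimensions, so the base distribution has to live in curr_dim = 5 dimensions rather than n_dimension = 8. A minimal illustration (not code from the package):

import distrax
import jax.numpy as jnp

# latents after the 8 -> 5 surjective funnel layers
z = jnp.zeros((100, 5))

# 8-dimensional base distribution: shapes cannot broadcast
base_wrong = distrax.Independent(distrax.Normal(jnp.zeros(8), jnp.ones(8)), 1)
# base_wrong.log_prob(z)  # ValueError: Incompatible shapes for broadcasting: shapes=[(100, 5), (8,)]

# 5-dimensional (curr_dim) base distribution: shapes line up
base_fixed = distrax.Independent(distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1)
log_prob = base_fixed.log_prob(z)  # shape (100,)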

Cheers, Simon

Jice-Zeng commented 2 months ago

Yes, it works after I changed n_dimension to curr_dim. I also tried to implement a different surjective layer, MaskedCouplingInferenceFunnel, as you can see in the code I pasted above. Even after changing n_dimension to curr_dim there, I still get an error:

def _bijector_fn(params):
    means, log_scales = unstack(params, -1)   # <-- raises here
    return distrax.ScalarAffine(means, jnp.exp(log_scales))

ValueError: too many values to unpack (expected 2)

Do you have any idea for implementing MaskedCouplingInferenceFunnel?

dirmeier commented 2 months ago

The issue here is that a MADE network outputs different parameter shapes than an MLP. I think in this case, you need to use the following bijector:

def bijector_fn(params):
    shift, log_scale = jnp.split(params, 2, axis=-1)
    return distrax.ScalarAffine(shift, jnp.exp(log_scale))
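
The difference is only in the parameter layout. A small, self-contained sketch of the two layouts (the unstack below is a hypothetical stand-in for the helper used in the snippet above):

import jax.numpy as jnp

def unstack(x, axis=-1):
    # hypothetical stand-in: split an array into a tuple of slices along `axis`
    return tuple(jnp.moveaxis(x, axis, 0))

# MADE-style conditioner output: parameters stacked in a trailing axis, shape (..., d, 2)
params_made = jnp.zeros((100, 8, 2))
means, log_scales = unstack(params_made, -1)            # two arrays of shape (100, 8)

# plain MLP conditioner output: a flat vector, shape (..., 2 * d)
params_mlp = jnp.zeros((100, 16))
shift, log_scale = jnp.split(params_mlp, 2, axis=-1)    # two arrays of shape (100, 8)

# unstacking the flat MLP output along the last axis yields 16 arrays,
# hence "ValueError: too many values to unpack (expected 2)"
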
dirmeier commented 2 months ago

Thanks for reporting all this! I need to put all of that in the documentation.

Jice-Zeng commented 2 months ago

> The issue here is that a MADE network outputs different parameter shapes than an MLP. I think in this case, you need to use the following bijector:
>
>     def bijector_fn(params):
>         shift, log_scale = jnp.split(params, 2, axis=-1)
>         return distrax.ScalarAffine(shift, jnp.exp(log_scale))

Hi Dirmeier, I think I misled you. The original make_maf works well after changing n_dimension to curr_dim. make_maf uses MaskedAutoregressive as the bijective layer and AffineMaskedAutoregressiveInferenceFunnel as the surjective layer, and its bijector

def _bijector_fn(params):
    means, log_scales = unstack(params, -1)
    return distrax.ScalarAffine(means, jnp.exp(log_scales))

also works well.

My issue is that I tried to implement an affine (surjective) masked coupling flow, with MaskedCoupling as the bijective layer and MaskedCouplingInferenceFunnel as the surjective layer. I have opened another issue, #46, to avoid confusion for other readers.