dirmeier / sbijax

Simulation-based inference in JAX
https://sbijax.rtfd.io
Apache License 2.0
21 stars 2 forks

How to sample y and compute log_prob using trained neural likelihood #44

Open Jice-Zeng opened 2 months ago

Jice-Zeng commented 2 months ago

Hi, I want to sample y and compute log_prob using a trained neural likelihood for the SLCP example, as follows:

import haiku as hk
import jax.numpy as jnp
from jax import random as jr
from jax.flatten_util import ravel_pytree

# Draw a parameter vector from the prior and simulate an observation.
theta1 = prior_fn().sample(seed=jr.PRNGKey(2), sample_shape=(1,))
print(theta1)

y_observed = simulator_fn(jr.PRNGKey(2), theta1)
print(y_observed)

rng_key_seq = hk.PRNGSequence(0)

# Flatten theta and tile it to match the batch dimension of y_observed.
theta_t, _ = ravel_pytree(theta1)
theta_tt = jnp.tile(theta_t, [y_observed.shape[0], 1])
print(theta_tt.shape)

# Evaluating the likelihood works.
lp = neural_network.apply(params_t, None, method="log_prob", y=y_observed, x=theta_tt)
print(lp)

# Sampling from the likelihood fails with a shape error.
sample_y = neural_network.apply(
    params_t, next(rng_key_seq), method="sample", sample_shape=(2000,),
    x=jnp.ones((2000, 1)) * theta_tt,
)
print(sample_y)

The lp is fine, but computing sample_y raises an error: TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 0 for shapes (2000, 5), (2000, 3). I do not see an example that samples from a trained neural likelihood function; can you give me some hints? Thanks!
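The shape mismatch can be reproduced in isolation. A minimal sketch (shapes taken from the traceback; the variable names y_plus and y_minus mirror the funnel code below) showing why jnp.concatenate rejects these arrays on the default axis 0 but accepts them on the feature axis:

```python
import jax.numpy as jnp

# The two halves produced by the funnel, with shapes from the traceback.
y_plus = jnp.zeros((2000, 5))
y_minus = jnp.zeros((2000, 3))

# Default axis=0 stacks along the batch dimension, so all remaining
# dimensions must match: (2000, 5) vs (2000, 3) differ in dimension 1.
try:
    jnp.concatenate([y_plus, y_minus])
    axis0_failed = False
except TypeError as e:
    axis0_failed = True
    print("axis=0 fails:", e)

# Concatenating along the last (feature) axis works: (2000, 5 + 3).
y = jnp.concatenate([y_plus, y_minus], axis=-1)
print(y.shape)  # (2000, 8)
```

For 2-D arrays, axis=-1 and axis=1 are equivalent here; the point is that the two halves share the batch dimension and should be joined along features.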

Jice-Zeng commented 2 months ago

The error is in the MaskedAutoregressiveInferenceFunnel class:

    def _forward_and_likelihood_contribution(self, z, x=None, **kwargs):
        z_condition = z
        if x is not None:
            z_condition = jnp.concatenate([z, x], axis=-1)
        y_minus, jac_det = self.decoder(z_condition).sample_and_log_prob(
            seed=hk.next_rng_key()
        )

        y_cond = y_minus
        if x is not None:
            y_cond = jnp.concatenate([y_cond, x], axis=-1)
        # TODO(simon): need to sort the indexes correctly (?)
        # TODO(simon): remote the conditioning here?
        y_plus, lc = self._inner_bijector().forward_and_log_det(z, y_cond)

        y = jnp.concatenate([y_plus, y_minus])
        return y, lc + jac_det

y = jnp.concatenate([y_plus, y_minus]) should be y = jnp.concatenate([y_plus, y_minus], axis=1). Then the error disappears.

dirmeier commented 2 months ago

Jice, thanks for the report. I will look at it and fix it as soon as time allows. As I was saying in the other thread, we did a major refactor in which I likely introduced bugs.