astrodeepnet / sbi_experiments

Simulation Based Inference experiments
MIT License

Smooth Normalizing Flow N dimensions #23

Closed Justinezgh closed 2 years ago

Justinezgh commented 2 years ago

It finally works in more than 2 dimensions :) I changed the bijector and the neural network in the normalizing flow and tested it in 6D: notebook

Justinezgh commented 2 years ago

The bijector works :) I learned a 6D Gaussian; green is the truth and blue is the prediction.

Learned with NLL: image

Learned with score matching, to check the score: image

notebook
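To make the two training criteria above concrete, here is a minimal sketch of an NLL loss and a score matching loss. It is an illustration only: a diagonal Gaussian with learnable mean and log-std stands in for the normalizing flow, and the names `log_prob`, `nll_loss` and `score_matching_loss` are hypothetical, not taken from the notebook.

```python
import jax
import jax.numpy as jnp

D = 6  # dimension used in the test described above

def log_prob(params, x):
    """Log-density of a diagonal Gaussian (stand-in for the flow's log_prob)."""
    mean, log_std = params
    z = (x - mean) / jnp.exp(log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2 * jnp.pi))

def nll_loss(params, x):
    """Negative log-likelihood averaged over a batch of samples."""
    return -jnp.mean(jax.vmap(lambda xi: log_prob(params, xi))(x))

def score_matching_loss(params, x, true_score):
    """Squared error between the model score d/dx log q(x) and a known target score."""
    model_score = jax.vmap(jax.grad(lambda xi: log_prob(params, xi)))(x)
    return jnp.mean(jnp.sum((model_score - true_score) ** 2, axis=-1))

# Toy target: a standard 6D Gaussian, whose score is -x.
x = jax.random.normal(jax.random.PRNGKey(0), (512, D))
true_score = -x

exact = (jnp.zeros(D), jnp.zeros(D))   # model equal to the target
shifted = (jnp.ones(D), jnp.zeros(D))  # model with a wrong mean

sm_exact = score_matching_loss(exact, x, true_score)
sm_shifted = score_matching_loss(shifted, x, true_score)
nll_exact = nll_loss(exact, x)
nll_shifted = nll_loss(shifted, x)
```

Both losses are smallest when the model matches the target, which is why fitting with one and checking with the other is a useful sanity test of the learned score.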

Justinezgh commented 2 years ago

I tested a first 6D Gaussian model (without latent variables):

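import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

d = 6             # dimension of the test
batch_size = 256  # placeholder: the value used in the notebook is not shown here

def likelihood(mu):
  # Independent normals centred on mu (std 0.08), truncated to [0.01, 0.99]
  return tfd.Independent(tfd.TruncatedNormal(mu, 0.08*jnp.ones(d), 0.01, 0.99), 1)

@jax.jit
def get_batch(seed):
  # Use separate keys for the prior draw and the likelihood draws
  key_prior, key_batch = jax.random.split(seed)

  prior = tfd.Independent(tfd.Uniform(0.1*jnp.ones(d), 0.9*jnp.ones(d)),
                          reinterpreted_batch_ndims=1)

  mu = prior.sample(batch_size, seed=key_prior)

  # One observation x ~ p(x | mu) per sampled mu
  batch = jax.vmap(lambda mu, key: likelihood(mu).sample(seed=key))(
      mu, jax.random.split(key_batch, batch_size))

  # Joint score: gradient of log p(x | mu) + log p(mu) with respect to mu.
  # mu must be the first lambda argument, since jax.grad differentiates argument 0.
  score = jax.vmap(jax.grad(lambda mu, x: likelihood(mu).log_prob(x)))(mu, batch)
  score += jax.vmap(jax.grad(prior.log_prob))(mu)

  return mu, batch, score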
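As a check on what the score term represents, the sketch below writes the truncated normal log-density by hand and differentiates it with respect to mu. It assumes the score is meant to be the gradient of the joint log-density with respect to mu, as the added prior term suggests. `trunc_normal_log_prob` is a hypothetical helper written in plain JAX, not a function from the notebook or from TFP.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

SIGMA, LOW, HIGH = 0.08, 0.01, 0.99  # values from the snippet above

def trunc_normal_log_prob(x, mu):
    """Log-density of independent normals truncated to [LOW, HIGH], summed over dims."""
    # The normalizing constant depends on mu, so it contributes to the score.
    log_norm = jnp.log(norm.cdf((HIGH - mu) / SIGMA) - norm.cdf((LOW - mu) / SIGMA))
    return jnp.sum(norm.logpdf(x, loc=mu, scale=SIGMA) - log_norm)

# argnums=1 selects mu, the second argument
score_wrt_mu = jax.grad(trunc_normal_log_prob, argnums=1)

x = jnp.array([0.50, 0.12])
mu = jnp.array([0.45, 0.10])  # the second coordinate sits close to the lower bound

score = score_wrt_mu(x, mu)
gaussian_score = (x - mu) / SIGMA**2  # score of the untruncated Gaussian
```

Far from the bounds the two scores agree, but close to a bound the truncation term shifts the score noticeably, so the Gaussian formula is not a safe shortcut for this model. The uniform prior has zero gradient inside its support, so the prior term only matters if mu leaves [0.1, 0.9].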

Learned without the score (the red point is the true mean): image

learned with loss = 1e-3 * score - nll: image
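The thread does not spell out the exact form or sign convention of this combined loss. The sketch below is one plausible reading, under the assumption that the score term is a squared error between the model's gradient with respect to mu and the simulator score, added to the NLL with the quoted weight of 1e-3. The toy conditional model `log_q` and the synthetic data are hypothetical stand-ins for the flow and for `get_batch`.

```python
import jax
import jax.numpy as jnp

D = 6
SIGMA = 0.08
SCORE_WEIGHT = 1e-3  # weight quoted in the thread

def log_q(params, mu, x):
    """Toy conditional density q(mu | x): diagonal Gaussian centred on x + shift."""
    shift, log_std = params
    z = (mu - (x + shift)) / jnp.exp(log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2 * jnp.pi))

def loss_fn(params, mu, x, score):
    # NLL term: -log q(mu | x), averaged over the batch
    nll = -jnp.mean(jax.vmap(lambda m, xi: log_q(params, m, xi))(mu, x))
    # Score term: squared error between d/dmu log q(mu | x) and the target score
    q_score = jax.vmap(jax.grad(lambda m, xi: log_q(params, m, xi)))(mu, x)
    score_err = jnp.mean(jnp.sum((q_score - score) ** 2, axis=-1))
    return nll + SCORE_WEIGHT * score_err

# Synthetic batch: flat prior, untruncated Gaussian likelihood (a simplification).
key_mu, key_x = jax.random.split(jax.random.PRNGKey(0))
mu = jax.random.uniform(key_mu, (256, D), minval=0.1, maxval=0.9)
x = mu + SIGMA * jax.random.normal(key_x, (256, D))
score = (x - mu) / SIGMA**2  # d/dmu log p(x | mu); the flat prior adds nothing

init = (jnp.zeros(D), jnp.zeros(D))                      # unit-variance initial guess
truth = (jnp.zeros(D), jnp.log(SIGMA) * jnp.ones(D))     # exact posterior for this toy
loss_init = loss_fn(init, mu, x, score)
loss_truth = loss_fn(truth, mu, x, score)
grads = jax.grad(loss_fn)(init, mu, x, score)
```

A small weight like 1e-3 keeps the score term from dominating, since score errors scale with 1/sigma^2 and are large when sigma is small.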