Closed: Justinezgh closed this issue 2 years ago.
The bijector works :) I learned a 6D Gaussian; green is the truth and blue is the prediction.
Learned with NLL:

Learned with score matching (to check the score):
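For reference, the two objectives compared above can be sketched in pure JAX (this is an illustration, not the notebook's code: `log_prob`, the 1D Gaussian stand-in, and the parameter names are assumptions). NLL fits the density directly; score matching instead fits the gradient of the log-density, which here can be checked against the analytic score of a Gaussian.

```python
import jax
import jax.numpy as jnp

def log_prob(params, x):
    # 1D Gaussian log-density with learnable mean and log-std (stand-in model)
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return -0.5 * ((x - mu) / sigma) ** 2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi)

def nll_loss(params, batch):
    # maximum likelihood: minimize the negative mean log-density
    return -jnp.mean(jax.vmap(lambda x: log_prob(params, x))(batch))

def score_loss(params, batch, true_score):
    # score matching against a known score: match d/dx log p(x)
    model_score = jax.vmap(jax.grad(lambda x: log_prob(params, x)))(batch)
    return jnp.mean((model_score - true_score) ** 2)

# sanity check: at the true parameters, the model score equals the analytic one
key = jax.random.PRNGKey(0)
batch = 0.5 + 0.1 * jax.random.normal(key, (256,))
true_params = (0.5, jnp.log(0.1))
analytic_score = -(batch - 0.5) / 0.1**2  # d/dx log N(x; 0.5, 0.1)
print(score_loss(true_params, batch, analytic_score))  # close to 0
```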
I tested a first 6D Gaussian model (without latent variables):
```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

d = 6             # dimensionality of the Gaussian
batch_size = 128  # example value

@jax.jit
def get_batch(seed):
    # uniform prior over the mean, restricted to [0.1, 0.9]^d
    prior = tfd.Independent(
        tfd.Uniform(0.1 * jnp.ones(d), 0.9 * jnp.ones(d)),
        reinterpreted_batch_ndims=1)
    mu = prior.sample(batch_size, seed=seed)

    # one truncated-normal observation per sampled mean
    likelihood = lambda mu: tfd.Independent(
        tfd.TruncatedNormal(mu, 0.08 * jnp.ones(d), 0.01, 0.99), 1)
    batch = jax.vmap(
        lambda mu, seed: likelihood(mu).sample(seed=seed)
    )(mu, jax.random.split(seed, batch_size))

    # gradient of the conditional log-density w.r.t. the sample,
    # plus the gradient of the prior log-density w.r.t. the mean
    score = jax.vmap(
        jax.grad(lambda batch, mu: likelihood(mu).log_prob(batch))
    )(batch, mu)
    score += jax.vmap(jax.grad(prior.log_prob))(mu)
    return mu, batch, score
```
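Without TFP at hand, the score computation in the snippet can be sanity-checked against a simplified, *untruncated* analogue (an assumption for illustration only: for a plain Gaussian the gradient of the log-density with respect to the sample has a closed form, whereas the truncated version also carries a normalization term).

```python
import jax
import jax.numpy as jnp

d = 6
sigma = 0.08

def cond_log_prob(batch, mu):
    # untruncated stand-in for the TruncatedNormal likelihood above
    return jnp.sum(-0.5 * ((batch - mu) / sigma) ** 2
                   - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi))

key = jax.random.PRNGKey(0)
mu = jax.random.uniform(key, (4, d), minval=0.1, maxval=0.9)
batch = mu + sigma * jax.random.normal(jax.random.PRNGKey(1), (4, d))

# same structure as the snippet: grad of the conditional log-density w.r.t. the sample
score = jax.vmap(jax.grad(cond_log_prob))(batch, mu)

# analytic check: d/dx log N(x; mu, sigma) = -(x - mu) / sigma^2
expected = -(batch - mu) / sigma**2
print(jnp.max(jnp.abs(score - expected)))  # close to 0
```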
Learned without the score term (the red point is the true mean):
Learned with `loss = 1e-3 * score - nll`:
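One plausible reading of that combined objective (a sketch, not the notebook's implementation: the stand-in model, the sign conventions, and the squared-error form of the score term are assumptions) is an NLL term plus a score-matching penalty weighted by 1e-3:

```python
import jax
import jax.numpy as jnp

def model_log_prob(params, theta, x):
    # stand-in for the conditional density model log q(theta | x):
    # a unit Gaussian whose mean is a linear function of x (assumption)
    mu = params["w"] * x + params["b"]
    return jnp.sum(-0.5 * (theta - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi))

def combined_loss(params, theta, x, true_score, lam=1e-3):
    # negative log-likelihood term
    nll = -jnp.mean(jax.vmap(model_log_prob, (None, 0, 0))(params, theta, x))
    # score term: match grad_theta log q(theta | x) to the known score
    model_score = jax.vmap(
        jax.grad(model_log_prob, argnums=1), (None, 0, 0))(params, theta, x)
    score_err = jnp.mean((model_score - true_score) ** 2)
    return nll + lam * score_err

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
theta = x + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (32, 3))
params = {"w": 1.0, "b": 0.0}
true_score = -(theta - x)  # analytic score of the stand-in model
print(combined_loss(params, theta, x, true_score))
```

With the true parameters, the score penalty is near zero, so the weight `lam` only matters while the model score is still wrong.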
Finally, it works for more than 2D :) I changed the bijector and the NN in the NF and tested it on 6D: notebook