kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

Sampling from arrays #142

Closed: ColCarroll closed this issue 7 months ago

ColCarroll commented 9 months ago

It looks as though flowMC fails to sample when the initial point is one-dimensional or has size 1.

Consider sampling from a multivariate normal:

import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys

n_chains = 10
n_dim = 2

def log_density(x, data):
    # Unnormalized log density of an isotropic Gaussian (variance 1/2 per dimension).
    return jnp.sum(-x**2)

rng_key_set = initialize_rng_keys(n_chains, seed=42)

# Normalizing flow used for the global proposals.
model = MaskedCouplingRQSpline(
    n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim,
    key=jax.random.PRNGKey(21))

# Local MALA sampler; the second argument turns on jit compilation.
local_sampler = MALA(log_density, True, params={"step_size": 0.1})

nf_sampler = Sampler(
    n_dim=n_dim,
    rng_key_set=rng_key_set,
    data={},
    local_sampler=local_sampler,
    nf_model=model,
    n_chains=n_chains)
  1. If we call
    nf_sampler.sample(1., {})

    the following error is raised:

    ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
  2. If we call
    nf_sampler.sample(jnp.ones(n_dim), {})  # where n_dim > 1

    the following error is raised:

    ValueError: Incompatible shapes for broadcasting: shapes=[(10, 50, 2), (2, 1)]
  3. If we call
    nf_sampler.sample(jnp.atleast_2d(jnp.ones(n_dim)), {})  # where n_dim > 1

    the code runs.
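The first two failures come down to JAX's shape rules and can be reproduced without flowMC. In this minimal sketch the zero array only stands in for the sampler's stored chains, shaped (n_chains, n_steps, n_dim) as in the error message, and the (n_dim, 1) array mirrors the other shape that message reports; how flowMC actually combines the two internally is an assumption, not something shown in this issue. The helper `raises_value_error` is hypothetical.

```python
import jax
import jax.numpy as jnp

n_chains, n_steps, n_dim = 10, 50, 2


def raises_value_error(fn):
    """Hypothetical helper: return True if calling fn() raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


# Case 1: the float 1. becomes a rank-0 array, so vmap has no axis 0 to map over.
case1_fails = raises_value_error(lambda: jax.vmap(jnp.sin)(jnp.asarray(1.0)))

# Stand-in for the stored chains, shaped (n_chains, n_steps, n_dim).
chains = jnp.zeros((n_chains, n_steps, n_dim))

# Case 2: the error reports the position as (2, 1); aligning shapes from the
# right, 1 vs 2 is fine but 2 vs 50 is incompatible.
case2_fails = raises_value_error(lambda: chains + jnp.ones((n_dim, 1)))

# Case 3: jnp.atleast_2d turns (n_dim,) into (1, n_dim), which broadcasts
# against (n_chains, n_steps, n_dim) by stretching the size-1 axis.
case3_shape = (chains + jnp.atleast_2d(jnp.ones(n_dim))).shape

print(case1_fails, case2_fails, case3_shape)  # True True (10, 50, 2)
```

Aligned from the right, (1, 2) matches (50, 2), whereas (2, 1) would need its leading 2 to match 50, which is consistent with only the `jnp.atleast_2d` version running.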

When n_dim = 1, the last two cases also fail with:

ValueError: diag input must be 1d or 2d
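That message comes from `jax.numpy.diag`, which only accepts 1-D or 2-D input. Where flowMC calls it is not shown in this issue; one plausible but unverified explanation is that a size-1 axis gets squeezed away when n_dim = 1, leaving a rank-0 array. The sketch below only shows that `jnp.diag` accepts a (1, 1) array but rejects what remains after `jnp.squeeze`.

```python
import jax.numpy as jnp

# One chain in one dimension: the jnp.atleast_2d input from case 3 with n_dim = 1.
position = jnp.ones((1, 1))

# A 2-D input is accepted: jnp.diag returns its diagonal, here of shape (1,).
diagonal_shape = jnp.diag(position).shape

# jnp.squeeze drops every size-1 axis, leaving a rank-0 array.
squeezed = jnp.squeeze(position)

try:
    jnp.diag(squeezed)
    rejected = False
except ValueError:
    # Expected to be the "diag input must be 1d or 2d" error from the report.
    rejected = True

print(diagonal_shape, squeezed.shape, rejected)  # (1,) () True
```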

I'm not sure whether this is intended behavior, but I figured I'd flag it!
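A possible workaround for the n_dim > 1 cases is to always pass one explicit starting point per chain. This assumes, based on the (10, 50, 2) chains shape in the broadcasting error, that the sampler expects an initial position of shape (n_chains, n_dim). `make_initial_position` is a hypothetical helper, not part of flowMC, and it does not address the n_dim = 1 failure.

```python
import jax
import jax.numpy as jnp


def make_initial_position(point, n_chains, n_dim, key=None, jitter=0.0):
    """Hypothetical helper: return an (n_chains, n_dim) array of start points.

    `point` may be a scalar, an (n_dim,) vector, a (1, n_dim) row, or already
    (n_chains, n_dim). If `key` is given, Gaussian noise with standard
    deviation `jitter` is added so the chains do not share one start point.
    """
    point = jnp.asarray(point, dtype=float)
    allowed = {(), (n_dim,), (1, n_dim), (n_chains, n_dim)}
    if point.shape not in allowed:
        raise ValueError(f"cannot turn shape {point.shape} into ({n_chains}, {n_dim})")
    # Repeat the point (or row) across chains; a full array is returned as-is.
    position = jnp.broadcast_to(point, (n_chains, n_dim))
    if key is not None:
        position = position + jitter * jax.random.normal(key, (n_chains, n_dim))
    return position


start = make_initial_position(
    jnp.ones(2), n_chains=10, n_dim=2, key=jax.random.PRNGKey(0), jitter=0.1)
print(start.shape)  # (10, 2)
```

With the setup above, this would be used as `nf_sampler.sample(make_initial_position(jnp.ones(n_dim), n_chains, n_dim), {})`.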