Closed by ColCarroll 7 months ago.
It looks as though `flowMC` fails to sample when the initial point is not 2-dimensional (a scalar or a 1-D array), or when it has size 1 (`n_dim = 1`).
Consider sampling from a multivariate normal:
```python
import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.HMC import HMC
from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys

n_chains = 10


def log_density(x, data):
    # Unnormalized log density of an isotropic Gaussian; `data` is unused.
    return jnp.sum(-x**2)


n_dim = 2
rng_key_set = initialize_rng_keys(n_chains, seed=42)

model = MaskedCouplingRQSpline(
    n_layers=3,
    hidden_size=[64, 64],
    num_bins=8,
    n_features=n_dim,
    key=jax.random.PRNGKey(21),
)
# The positional `True` is the jit flag.
local_sampler = MALA(log_density, True, params={"step_size": 0.1})

nf_sampler = Sampler(  # n_loop_training, n_loop_production, etc. left at their defaults
    n_dim=n_dim,
    rng_key_set=rng_key_set,
    data={},
    local_sampler=local_sampler,
    nf_model=model,
    n_chains=n_chains,
)
```
With a scalar initial point:

```python
nf_sampler.sample(1., {})
```

the following is thrown:

```
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
```
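This first message is JAX's own rank check in `jax.vmap`, not a flowMC-specific error. The sketch below assumes, without checking flowMC's source, that the sampler evaluates the log density for all chains by vmapping over a leading chain axis; `batched_log_density` is a hypothetical name used only for illustration. Mapping over axis 0 of an `(n_chains, n_dim)` array works, while a 0-d start point has no axis 0 to map over.

```python
import jax
import jax.numpy as jnp


def log_density(x, data):
    return jnp.sum(-x**2)


# Hypothetical stand-in for a per-chain evaluation: map over the leading
# (chain) axis of x and share `data` across all chains.
batched_log_density = jax.vmap(log_density, in_axes=(0, None))

n_chains, n_dim = 10, 2
positions = jnp.ones((n_chains, n_dim))
out = batched_log_density(positions, {})  # one log density per chain
print(out.shape)

# A 0-d start point cannot be mapped along axis 0, so vmap raises ValueError.
scalar_rejected = False
try:
    batched_log_density(jnp.asarray(1.0), {})
except ValueError as err:
    scalar_rejected = True
    print("scalar start point rejected:", err)
```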
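With a 1-D initial point:

```python
nf_sampler.sample(jnp.ones(n_dim), {})  # where n_dim > 1
```

the following is thrown:

```
ValueError: Incompatible shapes for broadcasting: shapes=[(10, 50, 2), (2, 1)]
```

If the initial point is made 2-D,

```python
nf_sampler.sample(jnp.atleast_2d(jnp.ones(n_dim)), {})  # where n_dim > 1
```

the code runs.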
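Since only a 2-D start point works for `n_dim > 1`, one user-side workaround is to always pass an array with one row per chain, `(n_chains, n_dim)`, which is the layout the `(10, 50, 2)` shape in the broadcasting error suggests. `to_chain_positions` below is a hypothetical helper, not a flowMC function; it only normalizes shapes, and whether an `(n_chains, 1)` array avoids the `n_dim = 1` failure is not established here.

```python
import jax
import jax.numpy as jnp


def to_chain_positions(x, n_chains, n_dim, key=None, jitter=0.0):
    """Hypothetical helper (not part of flowMC).

    Broadcasts a scalar, an (n_dim,) vector, a (1, n_dim) row, or an
    (n_chains, n_dim) array to shape (n_chains, n_dim). If `key` is given,
    adds Gaussian noise with standard deviation `jitter` so that the chains
    do not all start from the same point.
    """
    base = jnp.asarray(x, dtype=jnp.float32)
    positions = jnp.broadcast_to(base, (n_chains, n_dim))
    if key is not None:
        positions = positions + jitter * jax.random.normal(key, (n_chains, n_dim))
    return positions


n_chains = 10
# The start points tried in the report, plus the n_dim = 1 case.
cases = [(2, 1.0), (2, jnp.ones(2)), (2, jnp.atleast_2d(jnp.ones(2))), (1, jnp.ones(1))]
shapes = [to_chain_positions(x, n_chains, n_dim).shape for n_dim, x in cases]
print(shapes)

# Jittered start points, e.g. for nf_sampler.sample(initial_position, {}).
initial_position = to_chain_positions(
    jnp.zeros(2), n_chains, 2, key=jax.random.PRNGKey(0), jitter=1.0
)
```

Broadcasting alone gives every chain the same start point; the optional jitter spreads them out, which is usually what you want when the chains also feed the normalizing-flow training.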
If `n_dim = 1`, the last two cases also fail, with:

```
ValueError: diag input must be 1d or 2d
```
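This message is raised by `jax.numpy.diag`, which accepts only 1-D or 2-D input. One common way to end up with a 0-d array when there is a single feature is `jnp.cov`, which (like `numpy.cov`) squeezes its result, so a `(1, 1)` covariance comes back as a scalar. The sketch below demonstrates that mechanism under that assumption; it has not been traced through flowMC's source, so treat it as a hypothesis about where the 0-d array comes from.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# n_dim = 2: the covariance is (2, 2) and jnp.diag extracts its diagonal.
cov_2d = jnp.cov(jax.random.normal(key, (50, 2)), rowvar=False)
shape_2d = cov_2d.shape
var_2d = jnp.diag(cov_2d)

# n_dim = 1: jnp.cov squeezes the (1, 1) result down to a 0-d array,
# and jnp.diag rejects 0-d input.
cov_1d = jnp.cov(jax.random.normal(key, (50, 1)), rowvar=False)
diag_failed = False
try:
    jnp.diag(cov_1d)
except ValueError as err:
    diag_failed = True
    print("jnp.diag raised:", err)

# Restoring the 2-D shape before taking the diagonal avoids the error.
var_1d = jnp.diag(jnp.atleast_2d(cov_1d))
```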
I'm not sure if this is intended behavior, but I figured I'd flag it!