pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

sample_smc #319

Open myravian opened 3 months ago

myravian commented 3 months ago

Hi there,

I was testing pymc_experimental/inference/smc/sampling.py and noticed the following issues:

Thanks a lot for the SMC blackjax implementation, it's very useful!

Cheers, VIan

PS: here's some code that produces the error

```python
import numpy as np
import pymc as pm

# Assumed import: sample_smc_blackjax is the function defined in
# pymc_experimental/inference/smc/sampling.py
from pymc_experimental.inference.smc.sampling import sample_smc_blackjax as sample_smc

real_a = 0.2
real_b = 2
x = np.linspace(1, 100)
y = real_a * x + real_b + np.random.normal(0, 2, len(x))

with pm.Model() as model:
    a = pm.Normal("a", mu=10, sigma=10)
    b = pm.Normal("b", mu=10, sigma=10)

    # uncommenting either of the following lines produces an error
    # c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    # d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```

ciguaran commented 3 months ago

Hi, I can tackle this. Could someone assign the issue to me?

ciguaran commented 3 months ago

@myravian, could you try running your example from this branch? I may have a fix: https://github.com/ciguaran/pymc-experimental/tree/ciguaran_fix_smc_bj. Also, I'm super interested to know what you are using SMC for; it would be great if it became an example notebook on how to use it! Let me know.

myravian commented 2 months ago

Unfortunately I still have the same error message:

```
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 150, in sample_smc_blackjax
    total_iterations, particles, diagnosis = inference_loop(
                                             ^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 267, in inference_loop
    n_iter, final_state, _, diagnosis = jax.lax.while_loop(
                                        ^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 262, in one_step
    state, info = kernel.step(subk, state)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 167, in step_fn
    return kernel(
           ^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/adaptive_tempered.py", line 101, in kernel
    return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 143, in kernel
    smc_state, info = smc.base.step(
                      ^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/base.py", line 140, in step
    particles, update_info = update_fn(keys, particles)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 131, in mcmc_kernel
    state = mcmc_init_fn(position, tempered_logposterior_fn)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/mcmc/hmc.py", line 89, in init
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/blackjax/smc/tempered.py", line 126, in tempered_logposterior_fn
    logprior = logprior_fn(position)
               ^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/work/MULTIGRIS/mgris/sampling_smc_ciguaran.py", line 380, in logp_fn_wrap
    return logp_fn(*particles)[0]
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/tmpoc516ktt", line 29, in jax_funcified_fgraph
    tensor_variable_13 = dimshuffle_1(d_simplex_)
                         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py", line 69, in dimshuffle
    res = jnp.transpose(x, op.transposition)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/local/home/vleboute/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 681, in transpose
    return lax.transpose(a, axes_)
           ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: transpose permutation isn't a permutation of operand dimensions, got permutation (0,) for operand shape (1000, 1).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
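The last frame shows the immediate failure: a dimshuffle compiled as the permutation `(0,)`, i.e. for a 1-D input, received an array of shape `(1000, 1)`. With `n_particles=1000`, this suggests the logp graph was built for a single particle (where `d_simplex_`, the transformed value of the 2-component Dirichlet, has shape `(1,)`) but was called on all particles stacked together. The sketch below reproduces only that shape mismatch with a hypothetical toy logp; NumPy's `np.transpose` stands in for `jnp.transpose` and raises `ValueError` rather than JAX's `TypeError`. It is an illustration of the symptom, not the fix in the linked branch.

```python
import numpy as np

N_PARTICLES = 1000  # matches n_particles=1000 in the report


def single_particle_logp(d_simplex):
    """Hypothetical toy stand-in for a logp graph compiled for ONE particle.

    For one particle, the transformed Dirichlet value has shape (1,),
    so the compiled dimshuffle is the 1-D permutation (0,).
    """
    x = np.transpose(d_simplex, (0,))  # only valid for 1-D input
    return -0.5 * np.sum(x**2)  # placeholder density, not the Dirichlet logp


# All particles stacked together: shape (1000, 1), as in the traceback
particles = np.zeros((N_PARTICLES, 1))

# Passing the whole batch fails: permutation (0,) does not cover 2 axes
try:
    single_particle_logp(particles)
    batched_failed = False
except (ValueError, TypeError):
    batched_failed = True

# Evaluating one particle (one row of shape (1,)) at a time works
per_particle = np.array([single_particle_logp(p) for p in particles])
print(batched_failed, per_particle.shape)
```

With `jnp` in place of `np`, the per-particle loop would normally be written as `jax.vmap(single_particle_logp)(particles)`, which maps over the leading (particle) axis.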

I don't have a straightforward, simple illustration of the way I use SMC, but the gist of it is that I ran an astrophysical code to compute predictions for several hundred thousand parameter sets. Based on a set of observables, I try to infer the parameters. Of course there are various issues, such as the regularity/completeness of the grid and the interpolation, but the main issue is the complex, multimodal posterior distributions that we expect. In all the proof-of-concept and validation tests we did, SMC has been a great way to probe the prior space and to handle such difficult posteriors (provided the kernel parameters are well tuned). I'm by no means an expert in statistics and I rely a lot on empirical knowledge, so I'm sure I'm not doing everything right, though!

ciguaran commented 2 months ago

Could you share a full Python file that reproduces the error? I've run the example you posted at the very beginning and it works for me 🤔.

myravian commented 2 months ago

Here is the script:

```python
import pymc as pm

from sampling_smc_ciguaran import sample_smc_blackjax as sample_smc

with pm.Model() as model:
    c = pm.Normal("c", mu=10, sigma=10, shape=(1,))
    d = pm.Dirichlet("d", [1, 1])

    trace = sample_smc(
        n_particles=1000,
        kernel="HMC",
        inner_kernel_params={
            "step_size": 0.01,  # small values better
            "integration_steps": 20,
        },
        iterations_to_diagnose=10,
        target_essn=0.5,
        num_mcmc_steps=10,
    )
```

Maybe it has to do with the blackjax/jax versions (1.1.0/0.4.21 on my system).
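To compare environments when versions are suspected, a small stdlib-only sketch can report what is actually installed. It uses `importlib.metadata.version`, which reads installed distribution metadata; the helper name `installed_versions` and the package list are assumptions chosen for this thread. Note that it reports installed packages only, so a locally copied module such as `sampling_smc_ciguaran.py` is not covered.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distributions that appear in the traceback or in this thread
for name, ver in installed_versions(
    ["jax", "jaxlib", "blackjax", "pymc", "pytensor", "pymc-experimental"]
).items():
    print(f"{name}: {ver or 'not installed'}")
```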

ciguaran commented 2 months ago

Hi! I was able to run the example you just shared by installing pymc-experimental from the branch:

pip install git+https://github.com/ciguaran/pymc-experimental@ciguaran_fix_smc_bj

Is it possible that you are still using pymc-experimental from master?
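One way to answer this from the environment itself is to check how `pymc-experimental` was installed. The sketch below relies on `direct_url.json`, which pip records for direct-URL installs such as `pip install git+https://...@branch` (PEP 610); the helper names are hypothetical, and a missing file only means that no direct URL was recorded (for example, a plain index install).

```python
import json
from importlib.metadata import PackageNotFoundError, distribution


def describe_direct_url(info):
    """Format a parsed direct_url.json dict as 'url@revision' when a VCS revision is recorded."""
    vcs = info.get("vcs_info", {})
    # requested_revision is the branch/tag given after '@' on the command line;
    # commit_id is the resolved commit hash (fallback if no revision was requested)
    rev = vcs.get("requested_revision") or vcs.get("commit_id")
    return f"{info['url']}@{rev}" if rev else info["url"]


def install_source(dist_name):
    """Return None if not installed, a note if no direct URL was recorded, else the source URL."""
    try:
        dist = distribution(dist_name)
    except PackageNotFoundError:
        return None
    raw = dist.read_text("direct_url.json")  # None when the file does not exist
    if raw is None:
        return "no direct URL recorded"
    return describe_direct_url(json.loads(raw))


print("pymc-experimental:", install_source("pymc-experimental"))
```

A branch install like the one above should report the fork URL followed by `@ciguaran_fix_smc_bj`.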

myravian commented 2 months ago

You're right, I was not using the proper versions. I just tested it and it seems to work fine. Thanks for the modifications!