AlexanderFengler / Omission_omission

Study on the effect of ignoring omissions
MIT License
0 stars 0 forks source link

Get rid of deterministics while sampling #1

Open AlexanderFengler opened 1 month ago

AlexanderFengler commented 1 month ago

The new version of PyMC lets you choose which variables you actually want to sample. This allows you to skip all deterministics and store only the random variables when running MCMC.

Let's see whether that gives speedups. A secondary benefit is that it should substantially reduce RAM requirements while sampling.

AlexanderFengler commented 1 month ago

see this issue

AlexanderFengler commented 1 month ago

PyMC >=5.14.0 needs to be installed for this to work.

You get it automatically on a fresh install, e.g. mamba install -c conda-forge pymc

To collect all random variables without deterministics, use something like model.free_RVs. (You should be able to pass this to the new var_names argument in pm.sample().)

Jasonleng commented 1 month ago

@AlexanderFengler The fitting runs well but I got this error after the fitting is done, when I set var_names to free_RVs.


TypeError                                 Traceback (most recent call last)
Cell In[93], line 37
     33 p_trial1 = pm.Deterministic('p_trial1',p_subj[idx1] * sigma_p + mu_p)
     35 pm.CustomDist("choice_rt", v_trial1,a_trial1,z_trial1,t_trial1,theta_trial1,p_trial1,
     36     logp=lan_logp_op,observed=df[['rt','response']])
---> 37 ddm_blog_traces_numpyro_d = pm.sample(target_accept=0.9,nuts_sampler='numpyro',
     38     chains=2, draws=1000, tune=1000,initvals={'mu_v':0,'mu_a':0,'mu_theta':0,"mu_z":0,"mu_t":0,'mu_p':0},var_names=hierarchical.free_RVs
     39 )

File ~/.conda/envs/lan_pipe3/lib/python3.10/site-packages/pymc/sampling/mcmc.py:691, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    687     if not isinstance(step, NUTS):
    688         raise ValueError(
    689             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    690         )
--> 691     return _sample_external_nuts(
    692         sampler=nuts_sampler,
    693         draws=draws,
    694         tune=tune,
    695         chains=chains,
    696         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    697         random_seed=random_seed,
    698         initvals=initvals,
    699         model=model,
    700         var_names=var_names,
    701         progressbar=progressbar,
    702         idata_kwargs=idata_kwargs,
    703         nuts_sampler_kwargs=nuts_sampler_kwargs,
    704         **kwargs,
    705     )
    707 if isinstance(step, list):
    708     step = CompoundStep(step)

File ~/.conda/envs/lan_pipe3/lib/python3.10/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    348 elif sampler in ("numpyro", "blackjax"):
    349     import pymc.sampling.jax as pymc_jax
--> 351     idata = pymc_jax.sample_jax_nuts(
    352         draws=draws,
    353         tune=tune,
    354         chains=chains,
    355         target_accept=target_accept,
    356         random_seed=random_seed,
    357         initvals=initvals,
    358         model=model,
    359         var_names=var_names,
    360         progressbar=progressbar,
    361         nuts_sampler=sampler,
    362         idata_kwargs=idata_kwargs,
    363         **nuts_sampler_kwargs,
    364     )
    365     return idata
    367 else:

File ~/.conda/envs/lan_pipe3/lib/python3.10/site-packages/pymc/sampling/jax.py:582, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 582 result = _postprocess_samples(
    583     jax_fn,
    584     raw_mcmc_samples,
    585     postprocessing_backend=postprocessing_backend,
    586     postprocessing_vectorize=postprocessing_vectorize,
    587 )
    588 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
    590 if idata_kwargs is None:

File ~/.conda/envs/lan_pipe3/lib/python3.10/site-packages/pymc/sampling/jax.py:194, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize)
    188     jax_vfn = jax.vmap(jax_fn)
    189     _, outs = scan(
    190         lambda _, x: ((), jax_vfn(*x)),
    191         (),
    192         _device_put(t_raw_mcmc_samples, postprocessing_backend),
    193     )
--> 194     return [jnp.swapaxes(t, 0, 1) for t in outs]
    195 elif postprocessing_vectorize == "vmap":
    196     return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

TypeError: 'NoneType' object is not iterable