pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

`var_names` in `sample_numpyro_nuts` not working #6116

Open rtdew1 opened 2 years ago

rtdew1 commented 2 years ago

When I try to save only certain variables from sampling by passing `var_names` to `sample_numpyro_nuts`, PyMC raises `AttributeError: 'str' object has no attribute 'owner'`.

More specifically, I have a large model but only care about saving a few of its variables. I'm sampling with NumPyro NUTS on the GPU. To save memory, I tried the `var_names` argument of `sample_numpyro_nuts`, which seemed to be what I was looking for. But whenever I use it, it raises the same error, copied below. This seems like a bug.

Please provide a minimal, self-contained, and reproducible example.

import numpy as np
import pymc as pm
import pymc.sampling_jax

# Simulated data: 100 draws centered at 1 with unit scale
size = 100
Y = 1 + np.random.normal(size=size, scale=1)

basic_model = pm.Model()
with basic_model:
    scale = pm.HalfNormal("scale", sigma=1)
    loc = pm.Normal("loc", mu=0, sigma=10)
    Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)

with basic_model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        chains = 1, 
        tune = 1000,
        draws = 1000,
        var_names = ["loc"]
    )

Please provide the full traceback.

Complete error traceback

```python
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/Users/ryandew/routines/code/pymc/sim/mwe.py in ()
     12     Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)
     14 with basic_model:
---> 15     trace = pymc.sampling_jax.sample_numpyro_nuts(
     16         chains = 1,
     17         tune = 1000,
     18         draws = 1000,
     19         var_names = ["loc"]
     20     )

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:533, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    530 print("Sampling time = ", tic3 - tic2, file=sys.stdout)
    532 print("Transforming variables...", file=sys.stdout)
--> 533 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    534 result = jax.vmap(jax.vmap(jax_fn))(
    535     *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    536 )
    537 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
     75 def get_jaxified_graph(
     76     inputs: Optional[List[TensorVariable]] = None,
     77     outputs: Optional[List[TensorVariable]] = None,
...
    850 def expand(r: Variable) -> Optional[Iterator[Variable]]:
--> 851     if r.owner and (not blockers or r not in blockers):
    852         return reversed(r.owner.inputs)

AttributeError: 'str' object has no attribute 'owner'
```
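One possible reading of the traceback is that the strings in `var_names` reach the graph traversal unchanged, where each output is expected to be a variable object with an `.owner` attribute. The sketch below illustrates that failure mode in plain Python. `FakeVariable`, `walk_outputs`, and `resolve_var_names` are hypothetical stand-ins, and `named_vars` is an ordinary dict built for the example. None of this is PyMC's code or a confirmed fix.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a graph variable (not a PyMC class)."""
    name: str
    owner: Optional[object] = None  # the traversal reads this attribute


def walk_outputs(outputs: List[FakeVariable]) -> List[str]:
    """Mimic the traversal step: every output must expose `.owner`."""
    visited = []
    for r in outputs:
        _ = r.owner  # raises AttributeError when `r` is a plain string
        visited.append(r.name)
    return visited


def resolve_var_names(
    named_vars: Dict[str, FakeVariable], var_names: List[str]
) -> List[FakeVariable]:
    """Map user-supplied names to variable objects, failing early on unknown names."""
    missing = [n for n in var_names if n not in named_vars]
    if missing:
        raise KeyError(f"Unknown variable names: {missing}")
    return [named_vars[n] for n in var_names]


# Illustrative lookup table, assumed for this sketch only.
named_vars = {n: FakeVariable(n) for n in ("scale", "loc")}

# Passing raw strings triggers the same kind of AttributeError as above.
try:
    walk_outputs(["loc"])
except AttributeError as err:
    error_message = str(err)

# Resolving names to objects first lets the traversal succeed.
resolved = walk_outputs(resolve_var_names(named_vars, ["loc"]))
```

If this reading is right, converting names to variables before the graph is built would avoid the error, but that is an assumption based only on the traceback shown here.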

Please provide any additional information below.

Versions and main components

ricardoV94 commented 2 years ago

@junpenglao said it also fails in `pm.sample`. Can someone confirm?

junpenglao commented 2 years ago

I think this may also have been a feature of `pm.sample` in the past, but it was removed at some point. +1 that we should make sure the two APIs are unified.