When trying to selectively save certain variables from sampling using `var_names` in `sample_numpyro_nuts`, PyMC throws a mysterious error.
More specifically, I have a large model, but I only really care about saving a few of the variables. I'm using numpyro NUTS for sampling on the GPU. To save memory, I found the `var_names` argument of `sample_numpyro_nuts`, which seemed to be what I'm looking for. But whenever I use it, it throws the same mysterious error, copied below. This seems like a bug.
Please provide a minimal, self-contained, and reproducible example.
```python
import numpy as np
import pymc as pm
import pymc.sampling_jax

# True parameter values
size = 100
Y = 1 + np.random.normal(size=size, scale=1)

basic_model = pm.Model()
with basic_model:
    scale = pm.HalfNormal("scale", sigma=1)
    loc = pm.Normal("loc", mu=0, sigma=10)
    Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)

with basic_model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        chains=1,
        tune=1000,
        draws=1000,
        var_names=["loc"],
    )
```
Please provide the full traceback.
Complete error traceback
```python
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/Users/ryandew/routines/code/pymc/sim/mwe.py in ()
12 Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)
14 with basic_model:
---> 15 trace = pymc.sampling_jax.sample_numpyro_nuts(
16 chains = 1,
17 tune = 1000,
18 draws = 1000,
19 var_names = ["loc"]
20 )
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:533, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
530 print("Sampling time = ", tic3 - tic2, file=sys.stdout)
532 print("Transforming variables...", file=sys.stdout)
--> 533 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
534 result = jax.vmap(jax.vmap(jax_fn))(
535 *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
536 )
537 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
75 def get_jaxified_graph(
76 inputs: Optional[List[TensorVariable]] = None,
77 outputs: Optional[List[TensorVariable]] = None,
...
850 def expand(r: Variable) -> Optional[Iterator[Variable]]:
--> 851 if r.owner and (not blockers or r not in blockers):
852 return reversed(r.owner.inputs)
AttributeError: 'str' object has no attribute 'owner'
```
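From the last frame of the traceback, it looks like the graph walker calls `r.owner` on every requested output, but the entries of `var_names` reach it as plain strings rather than as model variables. The sketch below reproduces that failure mode without PyMC or Aesara. `Variable`, `Apply`, `ancestors` and `resolve_var_names` are hypothetical stand-ins I made up for illustration (not the real Aesara classes or PyMC functions), so this only illustrates the suspected cause and the kind of name-to-variable lookup that would avoid it; it is not PyMC's actual code or fix.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional


@dataclass(eq=False)
class Variable:
    """Hypothetical stand-in for a graph variable: a name plus an optional owner node."""
    name: str
    owner: Optional["Apply"] = None


@dataclass(eq=False)
class Apply:
    """Hypothetical stand-in for the node that produced a variable."""
    inputs: List[Variable] = field(default_factory=list)


def ancestors(outputs: Iterable[Variable]) -> List[Variable]:
    """Walk back from the outputs, reading `.owner` the way the traceback's `expand` does."""
    seen, stack, order = set(), list(outputs), []
    while stack:
        r = stack.pop()
        if id(r) in seen:
            continue
        seen.add(id(r))
        order.append(r)
        if r.owner:  # raises AttributeError if r is a plain str
            stack.extend(reversed(r.owner.inputs))
    return order


def resolve_var_names(all_vars: List[Variable], var_names: Optional[List[str]]) -> List[Variable]:
    """Map requested names onto variable objects; None keeps every variable."""
    if var_names is None:
        return list(all_vars)
    by_name = {v.name: v for v in all_vars}
    missing = [n for n in var_names if n not in by_name]
    if missing:
        raise KeyError(f"unknown variable names: {missing}")
    return [by_name[n] for n in var_names]


# Usage: passing the names straight through fails like the report above,
# while resolving them to variables first works.
scale = Variable("scale")
loc = Variable("loc")
all_vars = [scale, loc]

try:
    ancestors(["loc"])
except AttributeError as err:
    print(err)  # 'str' object has no attribute 'owner'

print([v.name for v in ancestors(resolve_var_names(all_vars, ["loc"]))])  # ['loc']
```

The design choice here is simply to look names up once, up front, so everything downstream only ever sees variable objects.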
Please provide any additional information below.
Versions and main components
PyMC version: 4.1.3
Aesara version: 2.7.7
Python version: 3.10.5
Operating system: macOS (the error also occurs on Linux)