pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: Regression in jax translation from 5.12 -> 5.13 #7263

Open velochy opened 2 months ago

velochy commented 2 months ago

Describe the issue:

I have a few models that require some rather complex tensor manipulation, and after moving from 5.12 to 5.13, quite a few of them broke with JAX errors.

As the models themselves are big and unwieldy, I have tried to recreate the same issue with a toy example. As you can see, it needs to be quite convoluted to elicit the error (requiring a model dimension, a call to pt.concatenate, and pt.set_subtensor), but I do run into it with more complex real use cases as well.

I have managed to work around it in some cases by avoiding pt.concatenate and instead creating an empty tensor and setting its parts via set_subtensor, but I have one model where even that runs into issues. So it would be very nice if it worked like it used to :)
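As a rough illustration of the two constructions mentioned above (prepending a column of ones with concatenate versus allocating an empty tensor and filling its parts), here is a NumPy analog of the toy model's `odds` layout. It is only a sketch that checks both constructions give the same values; it says nothing about how PyTensor or JAX handle the symbolic shapes. The sizes (2 rows, 6 `mw` columns) come from the reproducible example in this issue, and `free` is random stand-in data for the `N` draws.

```python
import numpy as np

n_rows, n_mw = 2, 6  # len(ns) and the length of the 'mw' coord in the toy example
rng = np.random.default_rng(0)
free = rng.normal(size=(n_rows, n_mw // 2 - 1))  # stand-in for the 'N' draws, shape (2, 2)

# Option 1: prepend a column of ones with concatenate (mirrors pt.concatenate)
modds_cat = np.concatenate([np.ones_like(free[:, :1]), free], axis=1)

# Option 2: allocate an empty array and fill its parts (mirrors the set_subtensor workaround)
modds_set = np.zeros((n_rows, n_mw // 2))
modds_set[:, :1] = 1.0
modds_set[:, 1:] = free

# Scatter the three columns into the even positions of 'odds' (mirrors pt.set_subtensor)
odds = np.zeros((n_rows, n_mw))
odds[:, [0, 2, 4]] = modds_cat
```

Both options yield a (2, 3) block whose first column is ones; the odd columns of `odds` stay at zero.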

The facts of the case:

Reproducible code example:

import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0], 
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))

    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1))
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)

    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)

    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)

    pm_jax.sample_numpyro_nuts()

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
     20 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
     22 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 24 pm_jax.sample_numpyro_nuts()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    564     raise ValueError(f"{nuts_sampler=} not recognized")
    566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
    568     model=model,
    569     target_accept=target_accept,
    570     tune=tune,
    571     draws=draws,
    572     chains=chains,
    573     chain_method=chain_method,
    574     progressbar=progressbar,
    575     random_seed=random_seed,
    576     initial_points=initial_points,
    577     nuts_kwargs=nuts_kwargs,
    578 )
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    481 if chains > 1:
    482     map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
    485     map_seed,
    486     init_params=initial_points,
    487     extra_fields=(
    488         "num_steps",
    489         "potential_energy",
    490         "energy",
    491         "adapt_state.step_size",
    492         "accept_prob",
    493         "diverging",
    494     ),
    495 )
    497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:650, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    648     states, last_state = _laxmap(partial_map_fn, map_args)
    649 elif self.chain_method == "parallel":
--> 650     states, last_state = pmap(partial_map_fn)(map_args)
    651 else:
    652     assert self.chain_method == "vectorized"

    [... skipping hidden 12 frame]

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    424 # Check if _sample_fn is None, then we need to initialize the sampler.
    425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426     new_init_state = self.sampler.init(
    427         rng_key,
    428         self.num_warmup,
    429         init_params,
    430         model_args=args,
    431         model_kwargs=kwargs,
    432     )
    433     init_state = new_init_state if init_state is None else init_state
    434 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:783, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    763 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    764     init_params,
    765     num_warmup=num_warmup,
   (...)
    780     rng_key=rng_key,
    781 )
    782 if is_prng_key(rng_key):
--> 783     init_state = hmc_init_fn(init_params, rng_key)
    784 else:
    785     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    786     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    787     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    788     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:763, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    760         dense_mass = [tuple(sorted(z))] if dense_mass else []
    761     assert isinstance(dense_mass, list)
--> 763 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    764     init_params,
    765     num_warmup=num_warmup,
    766     step_size=self._step_size,
    767     num_steps=self._num_steps,
    768     inverse_mass_matrix=inverse_mass_matrix,
    769     adapt_step_size=self._adapt_step_size,
    770     adapt_mass_matrix=self._adapt_mass_matrix,
    771     dense_mass=dense_mass,
    772     target_accept_prob=self._target_accept_prob,
    773     trajectory_length=self._trajectory_length,
    774     max_tree_depth=self._max_tree_depth,
    775     find_heuristic_step_size=self._find_heuristic_step_size,
    776     forward_mode_differentiation=self._forward_mode_differentiation,
    777     regularize_mass_matrix=self._regularize_mass_matrix,
    778     model_args=model_args,
    779     model_kwargs=model_kwargs,
    780     rng_key=rng_key,
    781 )
    782 if is_prng_key(rng_key):
    783     init_state = hmc_init_fn(init_params, rng_key)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:336, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    334 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    335 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 336 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    337 energy = vv_state.potential_energy + kinetic_fn(
    338     wa_state.inverse_mass_matrix, vv_state.r
    339 )
    340 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:282, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    274 """
    275 :param z: Position of the particle.
    276 :param r: Momentum of the particle.
   (...)
    279 :return: initial state for the integrator.
    280 """
    281 if potential_energy is None or z_grad is None:
--> 282     potential_energy, z_grad = _value_and_grad(
    283         potential_fn, z, forward_mode_differentiation
    284     )
    285 return IntegratorState(z, r, potential_energy, z_grad)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:250, in _value_and_grad(f, x, forward_mode_differentiation)
    248     return out, grads
    249 else:
--> 250     return value_and_grad(f, has_aux=False)(x)

    [... skipping hidden 8 frame]

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:156, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    155 def logp_fn_wrap(x):
--> 156     return logp_fn(*x)[0]

File /tmp/tmpfrmeiqr6:11, in jax_funcified_fgraph(N)
      9 tensor_variable_3 = elemwise_fn_2(tensor_variable_2, tensor_constant_1)
     10 # Alloc([[1.]], 2, Sub.0)
---> 11 tensor_variable_4 = alloc(tensor_constant_2, tensor_constant_3, tensor_variable_3)
     12 # Join(1, Alloc.0, N)
     13 tensor_variable_5 = join(tensor_constant_4, tensor_variable_4, N)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:47, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
     46 def alloc(x, *shape):
---> 47     res = jnp.broadcast_to(x, shape)
     48     Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
     49     return res

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:1222, in broadcast_to(array, shape)
   1218 @util.implements(np.broadcast_to, lax_description="""\
   1219 The JAX version does not necessarily return a view of the input.
   1220 """)
   1221 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
-> 1222   return util._broadcast_to(array, shape)

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/util.py:417, in _broadcast_to(arr, shape)
    415   shape = (shape,)
    416 # check that shape is concrete
--> 417 shape = core.canonicalize_shape(shape)  # type: ignore[arg-type]
    418 arr_shape = np.shape(arr)
    419 if core.definitely_equal_shape(arr_shape, shape):

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/core.py:2117, in canonicalize_shape(shape, context)
   2115 except TypeError:
   2116   pass
-> 2117 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /home/velochy/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:422 for pmap. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /tmp/tmpfrmeiqr6:5:24 (jax_funcified_fgraph)

  operation a:i64[] = pjit[
  name=_where
  jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let
      e:i64[] = select_n b d c
    in (e,) }
] f g h
    from line /tmp/tmpfrmeiqr6:7:24 (jax_funcified_fgraph)

  operation a:i64[] = sub b c
    from line /tmp/tmpfrmeiqr6:9:24 (jax_funcified_fgraph)
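The final frames show the root cause: `jnp.broadcast_to` needs a concrete shape, but under `pmap`/`jit` the length derived from the mutable `mw` dim arrives as a tracer. Below is a minimal, PyMC-free sketch of the same failure and of the `static_argnums` fix the message suggests, assuming a recent JAX; `alloc_like` is a hypothetical stand-in for PyTensor's JAX `Alloc`, not its actual implementation.

```python
import jax
import jax.numpy as jnp


def alloc_like(x, n):
    # Hypothetical stand-in for PyTensor's JAX Alloc: broadcast x to shape (2, n)
    return jnp.broadcast_to(x, (2, n))


def traced_shape_fails():
    # Under jit, n becomes a tracer, so the shape is not concrete and JAX raises a TypeError
    try:
        jax.jit(alloc_like)(jnp.float32(1.0), 3)
    except TypeError:
        return True
    return False


# Marking n as static keeps it a concrete Python int at trace time
static_alloc = jax.jit(alloc_like, static_argnums=1)
result = static_alloc(jnp.float32(1.0), 3)
```

In the PyMC case the shape value is computed inside the logp graph, so there is no argument to mark as static; making the length a constant (for example by freezing the dims) plays the same role.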

PyMC version information:

Fails on 5.13.1

Context for the issue:

No response

ricardoV94 commented 2 months ago

Does using this helper first, fix the problem? https://github.com/pymc-devs/pymc/blob/4bc84391893f1face230ed64241a339d4d9dbf62/pymc/model/transform/optimization.py#L23

velochy commented 2 months ago

How would one use it?

ricardoV94 commented 2 months ago
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as model:
    ...

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    pm.sample(nuts_sampler="numpyro")

velochy commented 2 months ago

Yes that works. Both for the toy example as well as the real model. Any downsides to doing that?

ricardoV94 commented 2 months ago

No

velochy commented 2 months ago

Well there seems to be one. It is now throwing errors if I add initvals to the models. Any workarounds for that?

ricardoV94 commented 2 months ago

Can you provide a minimum working example?

velochy commented 2 months ago
import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np
from pymc.model.transform.optimization import freeze_dims_and_data

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0], 
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))

    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1),initval=obs[:,:2])
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)

    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)
    #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)

    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    idata = pm_jax.sample_numpyro_nuts()

throws

NotImplementedError                       Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
     21     #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
     23     pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 25 frozen_model = freeze_dims_and_data(model)
     26 with frozen_model:
     27     idata = pm_jax.sample_numpyro_nuts()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/transform/optimization.py:34, in freeze_dims_and_data(model)
     23 def freeze_dims_and_data(model: Model) -> Model:
     24     """Recreate a Model with fixed RV dimensions and Data values.
     25 
     26     The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
   (...)
     32     are more restrictive about dynamic shapes such as JAX.
     33     """
---> 34     fg, memo = fgraph_from_model(model)
     36     # Replace mutable dim lengths and data by constants
     37     frozen_vars = {
     38         memo[dim_length]: constant(
     39             dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
   (...)
     42         if isinstance(dim_length, SharedVariable)
     43     }

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/fgraph.py:154, in fgraph_from_model(model, inlined_views)
    132 """Convert Model to FunctionGraph.
    133 
    134 See: model_from_fgraph
   (...)
    150     A dictionary mapping original model variables to the equivalent nodes in the fgraph.
    151 """
    153 if any(v is not None for v in model.rvs_to_initial_values.values()):
--> 154     raise NotImplementedError("Cannot convert models with non-default initial_values")
    156 if model.parent is not None:
    157     raise ValueError(
    158         "Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
    159     )

NotImplementedError: Cannot convert models with non-default initial_values

ricardoV94 commented 2 months ago

Right, we don't support custom initial values in model transformations. You should be able to specify them when calling pm.sample instead.

ricardoV94 commented 2 months ago

Or specify them after freezing the model, with model.set_initval or something

velochy commented 2 months ago

OK. Maybe it makes sense to deprecate the initval parameter on RVs then?

ricardoV94 commented 2 months ago

I think there was resistance to that, but it would be my preference

velochy commented 2 months ago

Um. Resistance to removing something that no longer works? Or do I misunderstand something?

ricardoV94 commented 2 months ago

It works, just not for model transformations like freeze_dims_and_data

ricardoV94 commented 2 months ago

That's why it's a NotImplementedError, not a ValueError or something like that.

velochy commented 2 months ago

And the JAX sampler only works with frozen dims if we want more complex PyTensor manipulations? So, by implication, the JAX sampler only works for complex models if you don't use initval on RVs. Is this by design?

ricardoV94 commented 2 months ago

Most models should sample fine in JAX without frozen dims, but we expect some hiccups like the one you found, which is why that helper was added. That helper does not work with custom initvals, but you can pass custom initvals directly to the sampler anyway.

It's not a final solution, but everyone should be able to do their thing right now.

ricardoV94 commented 2 months ago

Your model wouldn't have worked with mutable dims before the changes either.

You can change this line:

pt.zeros( (len(ns),model.dim_lengths['mw']) )

To:

pt.zeros((len(ns), len(model.coords['mw'])))

That will freeze the second dimension, instead of linking it to the mutable mw dim_length, which is the change from 5.12 to 5.13 that's hitting you.

velochy commented 2 months ago

Ok. I guess I have my answers, and you are right, I have all the tools needed to make it work. Thank you for your thorough answers @ricardoV94

jessegrabowski commented 2 months ago

The out-of-sample model pattern breaks in a JAX workflow due to the issues reported here. For example:

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={'a':[1]}) as m:
    mu = pm.Normal('mu', dims=['a'])
    x = pm.Normal('x', mu=mu, sigma=1, dims=['a'])
    idata = pm.sample(nuts_sampler='numpyro')

with pm.Model(coords={'a':[1]}) as new_m:
    mu = pm.Flat('mu', dims=['a'])
    x = pm.Normal('x', mu, sigma=2, dims=['a'])

frozen_new_m = freeze_dims_and_data(new_m)
with frozen_new_m:
    idata_pred = pm.sample_posterior_predictive(idata, var_names=['x'], 
                                                predictions=True, 
                                                compile_kwargs={'mode':'JAX'})

I guess some automatic initvals are being silently set when we use pm.sample_posterior_predictive in this way, which unexpectedly breaks freeze_dims_and_data.