velochy opened 2 months ago
Does using this helper first fix the problem? https://github.com/pymc-devs/pymc/blob/4bc84391893f1face230ed64241a339d4d9dbf62/pymc/model/transform/optimization.py#L23
How would one use it?
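import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as model:
    ...  # define the model as usual

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    pm.sample(nuts_sampler="numpyro")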
Yes that works. Both for the toy example as well as the real model. Any downsides to doing that?
No
Well, there seems to be one. It now throws errors if I add initvals to the model. Any workarounds for that?
Can you provide a minimum working example?
import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np
from pymc.model.transform.optimization import freeze_dims_and_data

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0],
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))
    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1),initval=obs[:,:2])
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)
    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)
    #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    idata = pm_jax.sample_numpyro_nuts()
This throws:
NotImplementedError                       Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
     21 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
     23 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 25 frozen_model = freeze_dims_and_data(model)
     26 with frozen_model:
     27     idata = pm_jax.sample_numpyro_nuts()

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/transform/optimization.py:34, in freeze_dims_and_data(model)
     23 def freeze_dims_and_data(model: Model) -> Model:
     24     """Recreate a Model with fixed RV dimensions and Data values.
     25 
     26     The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
   (...)
     32     are more restrictive about dynamic shapes such as JAX.
     33     """
---> 34     fg, memo = fgraph_from_model(model)
     36     # Replace mutable dim lengths and data by constants
     37     frozen_vars = {
     38         memo[dim_length]: constant(
     39             dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
   (...)
     42         if isinstance(dim_length, SharedVariable)
     43     }

File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/fgraph.py:154, in fgraph_from_model(model, inlined_views)
    132 """Convert Model to FunctionGraph.
    133 
    134 See: model_from_fgraph
   (...)
    150     A dictionary mapping original model variables to the equivalent nodes in the fgraph.
    151 """
    153 if any(v is not None for v in model.rvs_to_initial_values.values()):
--> 154     raise NotImplementedError("Cannot convert models with non-default initial_values")
    156 if model.parent is not None:
    157     raise ValueError(
    158         "Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
    159     )

NotImplementedError: Cannot convert models with non-default initial_values
Right, we don't support custom initial values in model transformations. You should be able to specify them when calling pm.sample instead, or set them after freezing the model with model.set_initval or something similar.
OK. Maybe it makes sense to deprecate the initval parameter on RVs, then?
I think there was resistance to that, but it would be my preference.
Um. Resistance to removing something that no longer works? Or do I misunderstand something?
It works, just not for model transformations like freeze_dims_and_data. That's why it's a NotImplementedError and not a ValueError or something like that.
And the JAX sampler only works with frozen dims if we want more complex PyTensor manipulations? So, by implication, the JAX sampler only works for complex models if you don't use initval on RVs. Is this by design?
Most models should sample fine in JAX without frozen dims, but we expect some hiccups like the one you found. That's why that helper was added. The helper does not work with custom initvals, but you can pass custom initvals directly to the sampler anyway.
It's not a final solution, but everyone should be able to do their thing right now.
Your model wouldn't have worked before the changes with mutable dims anyway.
You can change this line:
pt.zeros( (len(ns),model.dim_lengths['mw']) )
to:
pt.zeros( (len(ns), len(model.coords['mw'])) )
That will freeze the second dimension instead of linking it to the mutable mw dim_length, which is the change from 5.12 to 5.13 that's hitting you.
OK. I guess I have my answers, and you're right: I have all the tools needed to make it work. Thank you for your thorough answers, @ricardoV94.
The out-of-sample model pattern breaks in a JAX workflow due to the issues reported here. For example:
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={'a':[1]}) as m:
    mu = pm.Normal('mu', dims=['a'])
    x = pm.Normal('x', mu=mu, sigma=1, dims=['a'])
    idata = pm.sample(nuts_sampler='numpyro')

with pm.Model(coords={'a':[1]}) as new_m:
    mu = pm.Flat('mu', dims=['a'])
    x = pm.Normal('x', mu, sigma=2, dims=['a'])

frozen_new_m = freeze_dims_and_data(new_m)
with frozen_new_m:
    idata_pred = pm.sample_posterior_predictive(idata, var_names=['x'],
                                                predictions=True,
                                                compile_kwargs={'mode':'JAX'})
I guess some initvals are being set automatically and silently when pm.sample_posterior_predictive is used this way, which unexpectedly breaks freeze_dims_and_data.
Describe the issue:
I have a few models where I have to do some rather complex tensor manipulation, and when moving from 5.12 to 5.13, quite a few of them broke with JAX errors.
As the models themselves are big and unwieldy, I have tried to recreate the same issue with a toy example. As you can see, it has to be quite convoluted to elicit the error (requiring a model dimension, a call to pt.concatenate, and pt.set_subtensor), but I do run into it with more complex real use cases as well.
I have managed to work around it in some cases by avoiding pt.concatenate and instead creating an empty tensor and setting its parts via set_subtensor, but I have one model where even that runs into issues. So it would be very nice if it worked like it used to :)
The facts of the case:
Reproducible code example:
Error message:
PyMC version information:
Fails on 5.13.1
Context for the issue:
No response