jessegrabowski opened this issue 5 months ago
https://github.com/pymc-devs/pymc/issues/7268 may be a bit of a blocker.
Right now we can't do model transformations inside another model context (because we can't create a distinct model inside another model). I would love to get rid of this behavior, but that means breaking the nested model API. The nested model API is overkill for what is basically automatic variable-name prefixing...
Another point to consider is https://github.com/pymc-devs/pymc/discussions/7177. For caching to be usable, users have to have control over when and how to freeze their models, and we shouldn't interfere too much by transforming models under the hood on their behalf.
We could perhaps establish a compromise where we transform if the backend is JAX, and not otherwise.
Or perhaps we re-introduce `pm.ConstantData`. I would however not rename `pm.Data` to `MutableData`, and perhaps not reintroduce any `mutable` kwargs? That way most users don't have to see/think about it. Those that want to sample with JAX can go directly to `ConstantData`. I agree that it's a bit cumbersome to force users who want that to first define the model and then call `freeze_rv_and_dims`. How would they find out about `ConstantData`, though?
The `initval` thing... Flat/HalfFlat variables should no longer need custom initvals? If we are still setting those we should remove them; `finite_logp_point` works just fine for them. We can support models with custom initvals, I just dislike them at the model level and didn't want to bother representing them in the fgraph format. I still think they should just be passed to `pm.sample` when a custom initval is needed.
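As a minimal sketch of that workflow (the toy model and the chosen starting value are made up for illustration):

```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x")
    pm.Normal("y", mu=x, sigma=1.0, observed=np.array([0.1, -0.2, 0.3]))

    # Custom starting point supplied at sampling time, not stored on the model
    idata = pm.sample(initvals={"x": np.array(0.5)}, chains=2, tune=200, draws=200)
```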
I had in mind that it should be possible to do the necessary replacements on the forward sampling graph directly, like model -> forward_graph -> frozen_forward_graph -> sampling_function, without making the round trip model -> fgraph -> frozen_model -> frozen_forward_graph -> sampling_function. Is this not possible? I was thinking we just need to extract all the necessary shape information from the model and then apply rewrites that are essentially just `pt.specify_shape`.
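Something like the following rough sketch, using PyTensor's generic `specify_shape`/`clone_replace` machinery rather than any PyMC-specific helper (the data variable and downstream graph are stand-ins, not PyMC internals):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import clone_replace

# A shared "data" variable: its static shape is unknown, which is what trips up JAX
X = pytensor.shared(np.random.normal(size=(100, 3)), name="X")
print(X.type.shape)  # (None, None)

# Pin the static shape to the current value's shape
X_frozen = pt.specify_shape(X, X.get_value().shape)
print(X_frozen.type.shape)  # (100, 3)

# Swap the frozen variable into a stand-in forward-sampling graph
mu = pt.exp(X).sum(axis=-1)
frozen_mu = clone_replace(mu, replace={X: X_frozen})
# frozen_mu now carries the static shape information a JAX compilation would need
```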
Re: `ConstantData`, what we're doing now is actually better; I don't think it's correct to go backwards. It basically forces users to choose between `pm.set_data` and JAX forward sampling, which doesn't seem like something they should have to do. There should definitely be a way to have both.

For me, #7177 is more about allowing the use of `static_argnums` in the JAX functions we generate. Maybe that would solve a lot of our problems? In 99% of uses, the JAX function is completely re-generated each time it is needed anyway, so we essentially lose nothing by specifying inputs with unknown shape as `static_argnums`. I think we've talked about this before, but your answer didn't stick with me.
> I had in mind that it should be possible to do the necessary replacements on the forward sampling graph directly, like model -> forward_graph -> frozen_forward_graph -> sampling_function, without making the round trip model -> fgraph -> frozen_model -> frozen_forward_graph -> sampling_function. Is this not possible? I was thinking we just need to extract all the necessary shape information from the model and then apply rewrites that are essentially just `pt.specify_shape`.
I'm afraid of adding special logic inside the forward samplers that is JAX specific. I preferred the `freeze_rv_and_dims` route because it reuses generic code that has other applications.
RE: static argnums, we should just try it. Basically we need to replace a vector input by a tuple of scalars, all of which are `static_argnums` in the compiled function (since numpy arrays aren't hashable and so can't be used as `static_argnums`). I think that may still be a clean solution that doesn't require us to do anything here.
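A hedged sketch of that idea in plain JAX (not the PyTensor JAX backend itself), just to show why the shape-like input has to arrive as a hashable tuple rather than an array:

```python
from functools import partial

import jax
import numpy as np

# `shape` is static: JAX retraces when it changes, but inside the trace it is a
# concrete tuple, so shape-dependent ops are fine.
@partial(jax.jit, static_argnums=0)
def draw_normal(shape, key):
    return jax.random.normal(key, shape)

key = jax.random.PRNGKey(0)
print(draw_normal((5, 2), key).shape)   # works: tuples are hashable
# draw_normal(np.array([5, 2]), key)    # fails: ndarrays aren't hashable, can't be static
```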
I guess I had in mind that this hypothetical graph operation could just replace `freeze_rv_and_dims`. It's sort of a hack to go back to a model representation when we could just operate on the graph (logp or forward sampling) directly. From a rewrite perspective, there shouldn't be any difference between these, no?

But I agree that the whole nesting thing is a bit of a PITA, and it might just be better to attack that somehow.
If we sort out the nesting issue, you should be able to call `freeze_rv_and_dims` inside a model without worries. The question of why to go back from fgraph to model is something we can start tackling in #7268. Potentially `pm.Model` could become a thinner shell that just builds the corresponding `fgraph` under the hood, but that may take some time to do properly (i.e., not break everything accidentally).
I love the idea @jessegrabowski! That would definitely make that part of the API clearer to users, and I think it's welcome.

I wanna make sure I understand what `freeze_dims_and_data` does though: will that be a problem when users call `pm.set_data` with out-of-sample data before calling `pm.sample_posterior_predictive`?

And is that also needed with Numba mode?
`freeze_dims_and_data` is only required for JAX, because JAX doesn't allow dynamic array shapes inside JIT functions. This isn't the case for numba jit functions (you can `@njit` a function and then pass whatever shapes you like), so it wouldn't be necessary there. Not sure about pytorch.
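A toy illustration of the Numba point (not PyMC code): a single `@njit` specialization happily accepts arrays of different lengths, whereas a JAX JIT would retrace per shape.

```python
import numpy as np
from numba import njit

@njit
def total(x):
    return x.sum()

# The same compiled specialization serves both calls:
# only dtype/ndim matter to Numba, not the array length.
print(total(np.ones(3)), total(np.ones(1_000)))
```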
`freeze_dims_and_data(model)` creates a copy of `model`, with all shape information inserted from the available data/coords. So this code would not work with `pm.set_data`:
```python
with freeze_dims_and_data(model):
    pm.set_data({'X': df_test.values}, coords={'obs_idx': df_test.index})
    idata_pred = pm.sample_posterior_predictive(idata)
```
It would fail with a shape error, since the model was first frozen and then `pm.set_data` was called. But it would work if you do things in the other order: first update the data that will be used by `freeze_dims_and_data` to determine all the static shapes, then freeze:
```python
with model:
    pm.set_data({'X': df_test.values}, coords={'obs_idx': df_test.index})

with freeze_dims_and_data(model):
    idata_pred = pm.sample_posterior_predictive(idata)
```
Actually now that https://github.com/pymc-devs/pymc/pull/7352 is merged, you can even do this:
```python
with model:
    pm.set_data({'X': df_test.values}, coords={'obs_idx': df_test.index})

    with freeze_dims_and_data(model):
        idata_pred = pm.sample_posterior_predictive(idata)
```
Which is essentially what I envision happening automatically inside `pm.sample_*_predictive` if you pass `mode="JAX"`.
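A hypothetical wrapper (the helper name is made up; PyMC does not do this automatically today) showing roughly what that could look like:

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

def sample_pp_with_backend(idata, model, mode=None, **kwargs):
    """Hypothetical helper: freeze the model automatically when targeting JAX."""
    if mode == "JAX":
        # freeze_dims_and_data returns a copy of the model with static shapes baked in
        model = freeze_dims_and_data(model)
    return pm.sample_posterior_predictive(
        idata, model=model, compile_kwargs={"mode": mode}, **kwargs
    )
```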
Should be fixed by https://github.com/pymc-devs/pytensor/pull/1029
[Before / After comparison screenshots attached]
Context for the issue:
For models involving scan or multivariate normal distributions, you get big speedups by passing `compile_kwargs={'mode': 'JAX'}` to `pm.sample_prior_predictive` and `pm.sample_posterior_predictive`. This has already proven useful in statespace modeling (in pymc-experimental, https://github.com/pymc-devs/pymc-experimental/pull/346) and instrumental variable modeling (in CausalPy, https://github.com/pymc-labs/CausalPy/pull/345). In each of these cases the JAX backend offers significant speedups, and it is a highly desirable feature.

This was technically never a supported feature, but it could be made to work by consciously specifying the whole model to be static (e.g. using `pm.ConstantData` and avoiding `mutable` kwargs). After #7047 this is obviously no longer possible. The work-around is to use `freeze_dims_and_data`, but this is somewhat cumbersome, especially with prior predictive sampling, where a typical workflow calls `pm.sample_prior_predictive` in the model block at construction time. I have also come up with cases where `freeze_dims_and_data` fails. A trivial example is in predictive modeling using `pm.Flat` dummies -- this adds non-`None` entries to `model.rvs_to_initial_values`, causing `model_to_fgraph` to fail.

My proposal would be to simply add a "freezing" step to `compile_forward_sampling_function`. This would alleviate the need for users to be aware of the `freeze_dims_and_data` helper function, allow JAX forward sampling without breaking out of a single model context, and also support any future backend that requires all shape information to be known.

I would also propose to officially support and expose alternative forward sampling backends by promoting `backend=` or `mode=` to a kwarg in `pm.sample_*_predictive`, rather than hiding it inside `compile_kwargs`.
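For illustration only, a hypothetical user-facing call under this proposal (the `backend` kwarg does not exist today; currently you would freeze manually and pass `compile_kwargs={'mode': 'JAX'}`):

```python
# Hypothetical API: freezing happens inside the forward sampler, so pm.set_data
# keeps working and no explicit freeze_dims_and_data call is needed.
with model:
    pm.set_data({"X": df_test.values}, coords={"obs_idx": df_test.index})
    idata_pred = pm.sample_posterior_predictive(idata, backend="jax")
```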