pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Add function that caches sampling results #277

Open ricardoV94 opened 10 months ago

ricardoV94 commented 10 months ago
import pymc as pm
from pymc_experimental.utils.cache import cache_sampling

with pm.Model() as m:
    y_data = pm.MutableData("y_data", [0, 1, 2])
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", mu=x, observed=y_data)

    cache_sample = cache_sampling(pm.sample, dir="traces")
    idata1 = cache_sample(chains=2)

    # Cache hit! Returning stored result
    idata2 = cache_sample(chains=2)

    pm.set_data({"y_data": [1, 1, 1]})
    idata3 = cache_sample(chains=2)

assert idata1.posterior["x"].mean() == idata2.posterior["x"].mean()
assert idata1.posterior["x"].mean() != idata3.posterior["x"].mean()

twiecki commented 10 months ago

When would that be useful?

ricardoV94 commented 10 months ago

When rerunning notebooks or any workflow with saving/loading of traces where you might still be tinkering with the model.

You don't need to bother naming the traces or overwriting old ones, since the cache key is derived automatically from the model and its data.

fonnesbeck commented 10 months ago

I usually rely on things like MLflow for storing artifacts like this.

ricardoV94 commented 10 months ago

I'm not familiar with MLflow. The idea here is that it pairs the saved traces with the exact model and sampling function (and its arguments) that were used.

Basically the model and the function kwargs are the cache key.
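
To make the "model plus kwargs as cache key" idea concrete, here is a rough, library-agnostic sketch. It is not the proposed `cache_sampling` implementation: `make_cache_key`, `cached_call`, `fake_sample`, and the string used to describe the model are all hypothetical stand-ins, and the key is just a SHA-256 hash of a JSON dump. A real version would presumably need to hash the model graph itself and might store traces as netCDF rather than pickles.

```python
import hashlib
import json
import pickle
import tempfile
from pathlib import Path


def make_cache_key(model_description, data, kwargs):
    """Hash everything that should invalidate the cache when it changes."""
    payload = json.dumps(
        {"model": model_description, "data": data, "kwargs": kwargs},
        sort_keys=True,   # key must not depend on dict ordering
        default=repr,     # fall back to repr for non-JSON values
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def cached_call(fn, cache_dir, model_description, data):
    """Wrap `fn` so results are stored on disk under the derived key."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    def wrapper(**kwargs):
        # The key is computed at call time, so later changes to `data` are seen.
        key = make_cache_key(model_description, data, kwargs)
        path = cache_dir / f"{key}.pkl"
        if path.exists():
            with path.open("rb") as f:
                return pickle.load(f)  # cache hit
        result = fn(**kwargs)
        with path.open("wb") as f:
            pickle.dump(result, f)
        return result

    return wrapper


# Hypothetical stand-in for a sampler; it records how often it really runs.
calls = []


def fake_sample(chains=1):
    calls.append(chains)
    return {"chains": chains, "run": len(calls)}


data = {"y_data": [0, 1, 2]}
with tempfile.TemporaryDirectory() as tmp:
    sample = cached_call(fake_sample, tmp, "x ~ Normal(0, 1); y ~ Normal(x, 1)", data)
    first = sample(chains=2)
    second = sample(chains=2)    # same model, data, kwargs: served from disk
    data["y_data"] = [1, 1, 1]   # changing the data changes the key
    third = sample(chains=2)     # cache miss, the sampler runs again
```

Here the second call never reaches `fake_sample`, while the third one does because the data changed, which mirrors the `idata1`/`idata2`/`idata3` behaviour in the example at the top.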

Does this have any parallel to your workflow with MLflow?