arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

`InferenceData.to_dataframe()` can crash the kernel #2258

Status: Open (opened by drbenvincent 1 year ago)

drbenvincent commented 1 year ago

Describe the bug

When I sample from a model and try to extract the samples into a dataframe, the kernel crashes.

To Reproduce

```python
import arviz as az
import numpy as np
import pymc as pm

N = 10_000

with pm.Model() as model1:
    # Vector-valued variables of length N, with a single prior draw
    x = pm.Normal("x", shape=N)
    temp = pm.Normal("temp", shape=N)
    y = pm.Deterministic("y", x + 1 + np.sqrt(3) * temp)
    idata1 = pm.sample_prior_predictive(samples=1)

# Extract x and y, drop the size-1 sample dimension, convert to pandas
df = az.extract(idata1, group="prior", var_names=["x", "y"]).squeeze().to_dataframe()
df
```

When N=10_000, this works on my machine and gives the following result:

[Screenshot of the resulting dataframe, 2023-07-02]

Note that the dataframe has 100 million rows.

But values of N much larger than this crash the kernel. In my use case I am using N=100_000.
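A minimal sketch of why the row count grows as N squared, using plain xarray instead of PyMC. It assumes, as ArviZ does by default for unnamed dimensions, that `x` and `y` end up on separate dimensions such as `x_dim_0` and `y_dim_0`. `xarray.Dataset.to_dataframe()` indexes its result by the Cartesian product of all dimensions, so N = 10_000 gives 100 million rows and N = 100_000 would need 10 billion.

```python
import numpy as np
import xarray as xr

N = 3  # small stand-in for the issue's N = 10_000

rng = np.random.default_rng(0)
x = rng.normal(size=N)
y = x + 1 + np.sqrt(3) * rng.normal(size=N)

# Each variable has its own dimension, mimicking the default
# per-variable dims (x_dim_0, y_dim_0) that an unnamed shape=N produces.
ds = xr.Dataset({"x": ("x_dim_0", x), "y": ("y_dim_0", y)})

# to_dataframe() broadcasts both variables over x_dim_0 x y_dim_0,
# so the result has N * N rows instead of N.
df = ds.to_dataframe()
print(df.shape)  # (9, 2)
```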

This is related to the strategy of setting the shape of the variables to N and then drawing a single sample. If we instead keep the default scalar shape and ask for N samples, it works as expected:

```python
with pm.Model() as model2:
    # Scalar variables, with N prior draws
    x = pm.Normal("x")
    temp = pm.Normal("temp")
    y = pm.Deterministic("y", x + 1 + np.sqrt(3) * temp)
    idata2 = pm.sample_prior_predictive(samples=N)

df = az.extract(idata2, group="prior", var_names=["x", "y"]).squeeze().to_dataframe()
df
```
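For contrast, a sketch of why this second model behaves: with scalar variables, `az.extract` places both `x` and `y` on the same stacked `sample` dimension, so there is only one dimension to index by. The sketch below simplifies `sample` to a plain dimension; in ArviZ it is a MultiIndex over `chain` and `draw`.

```python
import numpy as np
import xarray as xr

N = 3  # small stand-in for the issue's N

rng = np.random.default_rng(0)
x = rng.normal(size=N)
y = x + 1 + np.sqrt(3) * rng.normal(size=N)

# Both variables share one dimension, so there is no Cartesian product:
# to_dataframe() returns one row per sample.
ds = xr.Dataset({"x": ("sample", x), "y": ("sample", y)})
df = ds.to_dataframe()
print(df.shape)  # (3, 2)
```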

Expected behavior

This may well be an edge case, but I would expect to get a dataframe with N rows, even with the original model1 strategy.
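One possible workaround for the model1 strategy, sketched under assumptions: rebuild the extracted Dataset so that the selected variables share a single dimension before calling `to_dataframe()`. The helper `align_on_shared_dim` and the dimension name `obs` are hypothetical, not part of ArviZ or xarray. The helper assumes every variable is 1-D with the same length and that element i of each variable belongs together, which holds here because `y` is computed elementwise from `x`. Declaring a shared named dimension in the model itself (PyMC's `coords` and `dims` arguments) should avoid the problem at the source.

```python
import numpy as np
import xarray as xr

N = 3  # small stand-in for the issue's N

rng = np.random.default_rng(0)
x = rng.normal(size=N)
y = x + 1 + np.sqrt(3) * rng.normal(size=N)

# Stand-in for the squeezed az.extract result of model1:
# one draw, with x and y on separate dimensions.
ds = xr.Dataset({"x": ("x_dim_0", x), "y": ("y_dim_0", y)})


def align_on_shared_dim(ds, var_names, dim="obs"):
    """Hypothetical helper: put 1-D variables of equal length on one shared dim."""
    arrays = {name: ds[name].values for name in var_names}
    shapes = {arr.shape for arr in arrays.values()}
    if len(shapes) != 1 or len(next(iter(shapes))) != 1:
        raise ValueError("expected 1-D variables of equal length")
    return xr.Dataset({name: (dim, arr) for name, arr in arrays.items()})


df = align_on_shared_dim(ds, ["x", "y"]).to_dataframe()
print(df.shape)  # (3, 2)
```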

Additional context

ArviZ version: 0.15.1

ahartikainen commented 1 year ago

This is expected behavior for xarray.Dataset.to_dataframe(): the data is returned in tidy (long) format, with one row per combination of coordinate values across all dimensions.

If you want the data in wide format, use the to_dataframe method of the InferenceData object instead:

```python
idata2.to_dataframe(groups="prior")[["x", "y"]]
```

We should add a var_names argument to this method and also add a function that supports xarray Datasets.
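A rough sketch of what such a Dataset function might look like. `dataset_to_wide_dataframe` is hypothetical and does not reproduce ArviZ's implementation. It gives one row per entry of the sample dimension and one column per variable element. The `name[i]` column labels are an assumed convention, and the sketch assumes every selected variable has the sample dimension. With an `az.extract` result, `ds.get_index("sample")` should return the stacked (chain, draw) MultiIndex, so rows keep those labels.

```python
import numpy as np
import pandas as pd
import xarray as xr


def dataset_to_wide_dataframe(ds, var_names=None, sample_dim="sample"):
    """Hypothetical helper: one row per sample, one column per variable element."""
    if var_names is None:
        var_names = list(ds.data_vars)
    columns = {}
    for name in var_names:
        da = ds[name]
        # Put the sample dimension first so each column is values[:, idx].
        other_dims = [d for d in da.dims if d != sample_dim]
        values = da.transpose(sample_dim, *other_dims).values
        if not other_dims:
            columns[name] = values  # scalar per sample: a single column
            continue
        # One column per element of the remaining dimensions, e.g. "x[0]".
        for idx in np.ndindex(*values.shape[1:]):
            label = f"{name}[{','.join(map(str, idx))}]"
            columns[label] = values[(slice(None), *idx)]
    return pd.DataFrame(columns, index=ds.get_index(sample_dim))


# Usage: two samples, a length-3 vector x and a scalar y per sample.
ds = xr.Dataset(
    {
        "x": (("sample", "x_dim_0"), np.arange(6.0).reshape(2, 3)),
        "y": ("sample", np.array([10.0, 20.0])),
    }
)
wide = dataset_to_wide_dataframe(ds, ["x", "y"])
print(wide.shape)  # (2, 4)
print(list(wide.columns))  # ['x[0]', 'x[1]', 'x[2]', 'y']
```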