pymc-devs / pymc-bart

https://www.pymc.io/projects/bart

`BART` crashing with `MutableData` #137

Open DanielRobertNicoud opened 9 months ago

DanielRobertNicoud commented 9 months ago

Describe the bug

`pymc_bart.BART` crashes when a `MutableData` container is passed as the response `Y`.

To Reproduce

import numpy as np
import pymc
import pymc_bart as pmb

X_train = np.random.normal(size=(100, 1))
y_train = np.random.normal(size=(100,))

with pymc.Model() as bart:
    # data containers
    X = pymc.MutableData("X", X_train)
    y = pymc.MutableData("y", y_train)
    # prior
    mu = pmb.BART("mu", X=X, Y=y, m=20)
    # sigma = pymc.HalfCauchy("sigma", beta=10)
    # likelihood
    likelihood = pymc.Normal("obs", mu=mu, sigma=.3, observed=y)

    idata = pymc.sample(random_seed=42)

Passing the NumPy array directly instead, `pmb.BART("mu", X=X, Y=y_train, m=20)`, works.

Expected behavior

The model should sample normally.

Additional context

pymc==5.10.1
pymc-bart==0.5.7

jabrantley commented 7 months ago

Not sure whether you have resolved this already. I have run into the same issue in the past, but I don't think `y` needs to be `MutableData`.

Here is a modified version of your example:

import numpy as np
import pymc
import pymc_bart as pmb

X_train = np.random.normal(size=(100, 1))
# np.log below needs a strictly positive response, so draw y from a log-normal
y_train = np.random.lognormal(size=(100,))

with pymc.Model() as bart:
    # data containers
    X = pymc.MutableData("X", X_train)
    # y = pymc.MutableData("y", y_train)

    # prior
    mu = pmb.BART("mu", X=X, Y=np.log(y_train), m=20)
    # sigma = pymc.HalfCauchy("sigma", beta=10)

    # likelihood
    _mu = pymc.math.exp(mu)
    likelihood = pymc.Normal("obs", mu=_mu, sigma=.3, observed=y_train, shape=_mu.shape)

    # Sample
    idata = pymc.sample(random_seed=42)

Sampling from the posterior predictive for new inputs then works fine, because the likelihood is defined with `shape=_mu.shape` and therefore follows the number of rows in `X`.

X_test = np.random.normal(size=(75, 1))
y_test = np.random.normal(size=(75,))

with bart:
    X.set_value(X_test)
    predict = pymc.sample_posterior_predictive(idata, predictions=True, random_seed=42)

I guess this does not answer why the error occurs, but this is how I have been using BART in my work.