lnccbrown / HSSM

Development of HSSM package

Cannot replicate graph in "One parameter is a Regression Target" in main tutorial #487

Closed: ThomasMurray14 closed this issue 3 months ago

ThomasMurray14 commented 4 months ago

Hi! Apologies if this has been posted several times; I've had some trouble raising issues with my GitHub account.

I've been following along with the main tutorial, and I can't reproduce the output in this section: https://lnccbrown.github.io/HSSM/tutorials/main_tutorial/#case-1-one-parameter-is-a-regression-target

From the graph my code produces, it looks like v is not determined by the regression coefficients; instead, those coefficients are treated the same as the other parameters:

[Screenshot: model graph produced by my code (ddm)]
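For reference, with an identity link the formula `v ~ 1 + x + y` should make v a trial-by-trial quantity computed from the three coefficients, not a single free parameter like a, z or t. A minimal NumPy sketch of that linear predictor, assuming the coefficient values used in the simulation (the names `X` and `beta` are illustrative only, not HSSM internals):

```python
import numpy as np

# Illustration only: with an identity link, "v ~ 1 + x + y" makes v a per-trial
# linear predictor, v_i = Intercept + b_x * x_i + b_y * y_i.
rng = np.random.default_rng(0)
n_trials = 1000
x = rng.uniform(-1, 1, size=n_trials)
y = rng.uniform(-1, 1, size=n_trials)

# Design matrix with an intercept column, mirroring the "1 + x + y" terms.
X = np.column_stack([np.ones(n_trials), x, y])

# Coefficients in the same order as the columns of X (values from the simulation).
beta = np.array([0.3, 0.8, 0.3])

# Identity link: the linear predictor is v itself, one value per trial.
v = X @ beta

print(v.shape)  # (1000,)
```

So in a correctly built model graph, as in the tutorial, v should show up as a per-trial node that depends on the Intercept, x and y coefficients.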

If I try to estimate the model anyway, the estimated v distribution looks way off:

[Screenshot: estimated v distribution (bad_ddm)]

I've installed HSSM (0.2.1) in a conda environment, following the installation instructions. Beyond that, I've only installed Spyder and, again following the instructions, pinned pymc=5.10, as the graph visualisation didn't work with the latest version.
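Since HSSM builds on Bambi and PyMC, the exact versions of all three are the most useful details to include in a report like this. A small sketch using only the standard library's `importlib.metadata`; the package list is an assumption, so add others (for example the simulator package) as needed:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions matter for this issue (assumed list; extend as needed).
PACKAGES = ["hssm", "pymc", "bambi"]


def installed_versions(packages):
    """Return a dict mapping each package name to its installed version, or None."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```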

My code is the same as in the tutorial, except that I've removed the theta parameter from the last column of the matrix, as the ddm model crashed when it was included:

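import numpy as np
import pandas as pd
import hssm
from ssms.basic_simulators.simulator import simulator  # import path may differ across ssm-simulators versions

# Set up trial-by-trial parameters. v is a linear combination of x and y: intercept = 0.3, beta for x = 0.8, beta for y = 0.3.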
intercept = 0.3
x = np.random.uniform(-1, 1, size=1000)
y = np.random.uniform(-1, 1, size=1000)
v = intercept + (0.8 * x) + (0.3 * y)

# Columns: v, a, z, t (the plain ddm takes these four parameters, so there is no theta column)
true_values = np.column_stack(
    [v, np.repeat([[1.5, 0.5, 0.2]], axis=0, repeats=1000)]
)

# Get model simulations
obs_ddm_reg_v = simulator(true_values, model="ddm", n_samples=1)
dataset_reg_v = pd.DataFrame(
    {
        "rt": obs_ddm_reg_v["rts"].flatten(),
        "response": obs_ddm_reg_v["choices"].flatten(),
        "x": x,
        "y": y,
    }
)

model_reg_v = hssm.HSSM(
    model="ddm",
    data=dataset_reg_v,
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + x + y",
            "prior": {
                "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
                "x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
                "y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
            },
            "link": "identity",
        }
    ],
)

model_reg_v.graph()
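Independently of the graphing problem, it can help to confirm that the simulated v is exactly the linear combination the formula is meant to recover. A minimal sanity check with NumPy's least squares; this is not how HSSM estimates the coefficients, it only verifies that the generating values (0.3, 0.8, 0.3) are identifiable from x and y:

```python
import numpy as np

# Regenerate the trial-by-trial drift rates as above (seeded for reproducibility).
rng = np.random.default_rng(1)
n_trials = 1000
x = rng.uniform(-1, 1, size=n_trials)
y = rng.uniform(-1, 1, size=n_trials)
true_beta = np.array([0.3, 0.8, 0.3])  # Intercept, x, y
v = true_beta[0] + true_beta[1] * x + true_beta[2] * y

# Ordinary least squares of v on [1, x, y]; v is noise-free, so the fit is exact.
X = np.column_stack([np.ones(n_trials), x, y])
beta_hat, *_ = np.linalg.lstsq(X, v, rcond=None)

print(np.round(beta_hat, 3))  # [0.3 0.8 0.3]
```

With a correctly wired regression, the posterior for v's Intercept, x and y coefficients would be expected to concentrate near these same values.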
digicosmos86 commented 4 months ago

Hi @ThomasMurray14! Thanks for reporting this issue; we are aware of it. Recent updates to PyMC and Bambi have introduced breaking changes in their APIs, and we are working on updating HSSM to accommodate them. Fixing the graphing issues is our priority. We will let you know when this is fixed.