lnccbrown / HSSM

Development of HSSM package

Cannot replicate graph in "One parameter is a Regression Target" in main tutorial #486

Status: closed (ThomasMurray14 closed this 4 months ago)

ThomasMurray14 commented 4 months ago

Hi! I've been following along with the main tutorial, and I can't reproduce the output in this section: https://lnccbrown.github.io/HSSM/tutorials/main_tutorial/#case-1-one-parameter-is-a-regression-target

From the graph my code produces, v does not appear to be determined by the regression coefficients; instead, those coefficients seem to be treated like the other parameters:

[Image: ddm (model graph produced by model_reg_v.graph())]
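For comparison, with `"formula": "v ~ 1 + x + y"` and an identity link, the model is meant to compute v for each trial from three regression coefficients (with the priors given in the code), rather than sampling v as a single free parameter:

$$
v_i = \beta_{\text{Intercept}} + \beta_x x_i + \beta_y y_i, \qquad i = 1, \dots, 1000
$$

$$
\beta_{\text{Intercept}} \sim \mathrm{Uniform}(-3, 3), \quad \beta_x \sim \mathrm{Uniform}(-1, 1), \quad \beta_y \sim \mathrm{Uniform}(-1, 1)
$$

The data were generated with $\beta_{\text{Intercept}} = 0.3$, $\beta_x = 0.8$ and $\beta_y = 0.3$, so one would expect the graph to show v depending on these three coefficients.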

If I try to estimate the model anyway, the posterior for v looks completely wrong:

[Image: bad_ddm (estimated distribution of v)]

I've installed HSSM 0.2.1 in a conda environment, following the installation instructions. The only additional thing I've done is install Spyder. I also installed pymc=5.10, as the instructions say, because the graph visualisation didn't work with the latest version.
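Since the graph rendering depends on the installed pymc version, it may help to report the exact versions in the environment. A minimal sketch using the standard library's `importlib.metadata`; the list of package names is an assumption about which ones matter here (`ssm-simulators` is the distribution that provides `ssms`), and the helper `installed_versions` is hypothetical, not part of HSSM:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Hypothetical helper: map each package to its installed version, or None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment
            found[name] = None
    return found


# Packages assumed relevant to model building and graph rendering
if __name__ == "__main__":
    packages = ["hssm", "pymc", "bambi", "pytensor", "ssm-simulators", "graphviz"]
    for name, ver in installed_versions(packages).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```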

My code is the same as in the tutorial, except that I removed the theta parameter from the last column of the matrix, because the ddm model crashed when it was included:

import hssm
import numpy as np
import pandas as pd
# simulator() is provided by the ssm-simulators package
from ssms.basic_simulators.simulator import simulator

# Trial-by-trial parameters: v is a linear combination of x and y
# (intercept = 0.3, beta for x = 0.8, beta for y = 0.3)
intercept = 0.3
x = np.random.uniform(-1, 1, size=1000)
y = np.random.uniform(-1, 1, size=1000)
v = intercept + (0.8 * x) + (0.3 * y)

# One row per trial: v, then the fixed DDM parameters a = 1.5, z = 0.5, t = 0.2
true_values = np.column_stack(
    [v, np.repeat([[1.5, 0.5, 0.2]], axis=0, repeats=1000)]
)

# Simulate one response (choice and RT) per trial
obs_ddm_reg_v = simulator(true_values, model="ddm", n_samples=1)
dataset_reg_v = pd.DataFrame(
    {
        "rt": obs_ddm_reg_v["rts"].flatten(),
        "response": obs_ddm_reg_v["choices"].flatten(),
        "x": x,
        "y": y,
    }
)

# Regress v on x and y; a, z and t are left to HSSM's defaults
model_reg_v = hssm.HSSM(
    model="ddm",
    data=dataset_reg_v,
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + x + y",
            "prior": {
                "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
                "x": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
                "y": {"name": "Uniform", "lower": -1.0, "upper": 1.0},
            },
            "link": "identity",
        }
    ],
)

model_reg_v.graph()
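
One way to rule out the data-generation step, independently of HSSM and of the graph rendering, is to confirm that the simulated v really is the linear combination that the formula `v ~ 1 + x + y` assumes. The sketch below uses only NumPy; it swaps `np.random.uniform` for a seeded `np.random.default_rng` (an addition, purely for reproducibility), builds the same design matrix (Intercept, x, y), and recovers the coefficients by ordinary least squares. It does not exercise HSSM's own model-building code.

```python
import numpy as np

# Regenerate covariates as in the issue's code (seed added for reproducibility)
rng = np.random.default_rng(0)
n_trials = 1000
x = rng.uniform(-1, 1, size=n_trials)
y = rng.uniform(-1, 1, size=n_trials)

# True generating values from the issue: v = 0.3 + 0.8 * x + 0.3 * y
true_coefs = np.array([0.3, 0.8, 0.3])
v = true_coefs[0] + true_coefs[1] * x + true_coefs[2] * y

# Design matrix matching "v ~ 1 + x + y": column of ones (Intercept), then x, then y
design = np.column_stack([np.ones(n_trials), x, y])

# v contains no noise, so least squares recovers the coefficients
# up to floating-point error
recovered, *_ = np.linalg.lstsq(design, v, rcond=None)
print(np.round(recovered, 6))
```

If this recovers the generating values, the problem lies in how the regression is wired into the model rather than in the simulated data.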

digicosmos86 commented 4 months ago

Duplicate of #459