itsdfish / SequentialSamplingModels.jl

A unified interface for simulating and evaluating sequential sampling models in Julia.
https://itsdfish.github.io/SequentialSamplingModels.jl/dev/
MIT License

predict() broken? #48

Closed: DominiqueMakowski closed this issue 12 months ago.

DominiqueMakowski commented 12 months ago

I am encountering problems with predict(), even though I am using code that previously worked (I think). Can you reproduce the error?

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra
using StatsPlots
using Pigeons

Random.seed!(45461)

# Generate some data with known parameters
dist = LBA(ν=[3.0, 2.0], A=0.8, k=0.2, τ=0.3)
data = rand(dist, 100)

# Specify LBA model
@model function model_lba(data; min_rt=0.2)
    # Priors
    ν ~ MvNormal(zeros(2), I * 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    τ ~ Uniform(0.0, min_rt)

    # Likelihood
    data ~ LBA(; ν, A, k, τ)
end

chain = sample(model_lba(data; min_rt=minimum(data.rt)), NUTS(), 1000)

dat0 = [(missing) for i in 1:length(data.rt)]
pred = predict(model_lba(dat0; min_rt=minimum(data.rt)), chain)

A second issue, with a different error:

using DataFrames
using Random
using SequentialSamplingModels
using StatsModels
using Turing

# Generate data with different drifts for two conditions A vs. B
function make_data(; difference=1, n_groups=5, n_obs=100)
    n_obs_pergroup = n_obs ÷ n_groups

    # Create the first group (baseline / intercept)
    drift = [1.5, 0.5]  # Baseline
    df = DataFrame(rand(LBA(ν=drift, A=0.5, k=0.2, τ=0.3), n_obs_pergroup))

    # Compute the parameter change for each group
    change_drift = [
        difference / (n_groups - 1),
        difference / (n_groups - 1) / 2
    ]

    # Add new groups
    for g in 2:n_groups
        drift += change_drift  # new drift
        df = vcat(df, DataFrame(rand(LBA(ν=drift, A=0.5, k=0.2, τ=0.3), n_obs_pergroup)))
    end

    # Add condition column (if fewer than 5 groups, use letters for categorical groups)
    # Otherwise, assume continuous
    if n_groups <= 4
        df.x = repeat(["A", "B", "C", "D"][1:n_groups], inner=n_obs_pergroup)
    else
        df.x = repeat(range(0, 1, length=n_groups), inner=n_obs_pergroup)
    end

    return df
end

# Define models
@model function model_lba(data; min_rt=0.2, x=nothing)
    # Priors for auxiliary parameters
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    tau ~ Uniform(0.0, min_rt)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)

    # Priors for coefficients
    drift_intercept ~ filldist(Normal(0, 1), 2)
    drift_x ~ filldist(Normal(0, 1), 2)

    for i in eachindex(data)
        drifts = drift_intercept .+ drift_x * x[i]
        data[i] ~ LBA(; ν=drifts, τ=tau, A=A, k=k)
    end
end

# Generate data
Random.seed!(6)
df = make_data(difference=1, n_groups=2, n_obs=100)

# Format input data
f = @formula(rt ~ 1 + x)
f = apply_schema(f, schema(f, df))
_, predictors = coefnames(f)
X = modelmatrix(f, df)

# Sample model
data = [(choice=df.choice[i], rt=df.rt[i]) for i in 1:nrow(df)]  # Format data
chain = sample(model_lba(data; min_rt=minimum(df.rt), x=X[:, 2]), NUTS(), 1000)

# Predict
dat0 = [(missing) for i in 1:nrow(df)]
pred = predict(model_lba(dat0; min_rt=minimum(df.rt), x=X[:, 2]), chain)

itsdfish commented 12 months ago

Can you tell me which version of Turing and SequentialSamplingModels you are using?
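
For reference, one way to collect this information is sketched below. It assumes that Turing and SequentialSamplingModels are direct dependencies of the active project environment, so Pkg.status() lists them with their installed versions.

```julia
using Pkg

# Julia version, useful alongside the package versions in a bug report
println("Julia ", VERSION)

# Lists the direct dependencies of the active project environment with
# their installed versions (e.g. Turing and SequentialSamplingModels).
Pkg.status()
```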

itsdfish commented 12 months ago

It's difficult to diagnose without version info. Nonetheless, on my development branch turingext, both examples work, but there was a version conflict with Pigeons. Here is what I found:

In your first example, you pass an array of missing values. What you need is dat0 = missing, because data ~ LBA(; ν, A, k, τ) operates on the entire dataset at once, whereas in the second example you loop through the data and correctly access each missing element.
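
The distinction can be illustrated without Turing at all. The sketch below assumes, following the Turing documentation on missing data, that a left-hand side of ~ is sampled rather than observed only when its whole value is missing; is_sampled is a hypothetical helper that approximates that check and is not part of Turing's API.

```julia
# Hypothetical helper approximating how Turing decides whether the
# left-hand side of `~` is sampled: only when its whole value is `missing`.
is_sampled(x) = x === missing

# What `data ~ LBA(...)` needs: the entire dataset is `missing`.
whole = missing

# What `data[i] ~ LBA(...)` needs: each element is `missing`.
elementwise = [missing for _ in 1:3]

println(is_sampled(whole))              # true
println(is_sampled(elementwise))        # false: a Vector{Missing} is not `missing`
println(all(is_sampled, elementwise))   # true: every element is `missing`
```

Applied to the first example, and assuming the rest of the script is unchanged, the prediction step would then read pred = predict(model_lba(missing; min_rt=minimum(data.rt)), chain).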

itsdfish commented 12 months ago

As soon as I figure out the version conflict issue, I will push a new version that should work better with predict.

DominiqueMakowski commented 12 months ago

Damn, apologies! I ran update, so I assumed I had the latest version, but it turns out that, for some reason, the package was downgraded all the way to 0.3.4... Sorry, I should have checked before opening the issue :)
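
Because update can leave a package on an older release when another dependency's compatibility bounds hold it back (possibly related to the Pigeons version conflict mentioned earlier), a quick check like the following may help. It assumes Julia 1.8 or newer, where Pkg.status accepts the outdated keyword.

```julia
using Pkg

# Show only packages that are not on their latest release; Pkg marks
# those that are held back by compatibility constraints (Julia 1.8+).
Pkg.status(; outdated = true)
```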

itsdfish commented 12 months ago

No problem!

Let me know if it still doesn't work with a newer version.

DominiqueMakowski commented 12 months ago

It works now, all good.