itsdfish / SequentialSamplingModels.jl

A unified interface for simulating and evaluating sequential sampling models in Julia.
https://itsdfish.github.io/SequentialSamplingModels.jl/dev/
MIT License

Wald model: domain error (InverseGaussian: the condition μ > zero(μ) is not satisfied) #76

Closed: DominiqueMakowski closed this issue 3 months ago

DominiqueMakowski commented 3 months ago

I fitted a fairly simple Wald model; the diagnostics look good and it works when making predictions on the original data. However, when I try to generate predictions on new data, it fails with the following error:

julia> pred = predict(model_wald([(missing) for i in 1:length(grid)]; min_rt=minimum(df.RT), isi=grid), chain_wald)
ERROR: DomainError with -14.711588130987748:
InverseGaussian: the condition μ > zero(μ) is not satisfied.
Stacktrace:
  [1] #313
    @ C:\Users\domma\.julia\packages\Distributions\ji8PW\src\univariate\continuous\inversegaussian.jl:33 [inlined]
  [2] check_args
    @ C:\Users\domma\.julia\packages\Distributions\ji8PW\src\utils.jl:89 [inlined]
  [3] #InverseGaussian#312
    @ C:\Users\domma\.julia\packages\Distributions\ji8PW\src\univariate\continuous\inversegaussian.jl:33 [inlined]
  [4] InverseGaussian
    @ C:\Users\domma\.julia\packages\Distributions\ji8PW\src\univariate\continuous\inversegaussian.jl:32 [inlined]
  [5] rand
    @ C:\Users\domma\.julia\packages\SequentialSamplingModels\dYw8c\src\Wald.jl:66 [inlined]
  [6] init
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\sampler.jl:24 [inlined]
  [7] assume(rng::Random.TaskLocalRNG, sampler::DynamicPPL.SampleFromPrior, dist::Wald{…}, vn::AbstractPPL.VarName{…}, vi::DynamicPPL.TypedVarInfo{…})
    @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\context_implementations.jl:234
  [8] tilde_assume
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\context_implementations.jl:70 [inlined]
  [9] tilde_assume
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\context_implementations.jl:67 [inlined]
 [10] tilde_assume
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\context_implementations.jl:52 [inlined]
 [11] tilde_assume!!(context::DynamicPPL.SamplingContext{…}, right::Wald{…}, vn::AbstractPPL.VarName{…}, vi::DynamicPPL.TypedVarInfo{…})
    @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\context_implementations.jl:138
 [12] macro expansion
    @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\models_nonrandom.jl:66 [inlined]
 [13] macro expansion
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\compiler.jl:555 [inlined]
 [14] model_wald(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.TypedVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, data::Vector{…}; min_rt::Float64, isi::Vector{…})
    @ Main c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\models_nonrandom.jl:62
 [15] model_wald
    @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\models_nonrandom.jl:48 [inlined]
 [16] _evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\model.jl:963
 [17] evaluate_threadunsafe!!
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\model.jl:936 [inlined]
 [18] evaluate!!
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\model.jl:889 [inlined]
 [19] evaluate!! (repeats 2 times)
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\model.jl:900 [inlined]
 [20] Model
    @ C:\Users\domma\.julia\packages\DynamicPPL\rXg4T\src\model.jl:860 [inlined]
 [21] #116
    @ C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:676 [inlined]
 [22] iterate
    @ .\generator.jl:47 [inlined]
 [23] collect_to!(dest::Matrix{…}, itr::Base.Generator{…}, offs::Int64, st::Tuple{…})
    @ Base .\array.jl:892
 [24] collect_to_with_first!(dest::Matrix{…}, v1::Turing.Inference.Transition{…}, itr::Base.Generator{…}, st::Tuple{…})
    @ Base .\array.jl:870
 [25] collect(itr::Base.Generator{Base.Iterators.ProductIterator{…}, Turing.Inference.var"#116#117"{…}})
    @ Base .\array.jl:844
 [26] map(f::Function, A::Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
    @ Base .\abstractarray.jl:3313
 [27] #transitions_from_chain#115
    @ C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:673 [inlined]
 [28] predict(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, chain::Chains{…}; include_all::Bool)
    @ Turing.Inference C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:583
 [29] predict
    @ C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:576 [inlined]
 [30] #predict#108
    @ C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:574 [inlined]
 [31] predict(model::DynamicPPL.Model{…}, chain::Chains{…})
    @ Turing.Inference C:\Users\domma\.julia\packages\Turing\iM84I\src\mcmc\Inference.jl:573
 [32] top-level scope
    @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\models_nonrandom.jl:119
Some type information was truncated. Use `show(err)` to see complete types.

The trace suggests that the error is raised by the InverseGaussian constructor?

Any thoughts?

Here's the model just in case, though I wouldn't say the issue lies with it:

@model function model_wald(data; min_rt=minimum(data.rt), isi=nothing)

    # Transform ISI into polynomials
    isi = data_poly(isi, 2; orthogonal=true)

    # Priors for coefficients
    drift_intercept ~ truncated(Normal(5, 2), 0.0, Inf)
    drift_isi1 ~ Normal(0, 1)
    drift_isi2 ~ Normal(0, 1)

    # Priors
    α ~ truncated(Normal(0.5, 0.4), 0.0, Inf)
    τ ~ truncated(Normal(0.2, 0.05), 0.0, min_rt)

    for i in 1:length(data)
        drift = drift_intercept
        drift += drift_isi1 * isi[i, 1]
        drift += drift_isi2 * isi[i, 2]
        data[i] ~ Wald(drift, α, τ)
    end
end
itsdfish commented 3 months ago

My guess is that you might need to constrain the drift rate somehow to be positive, but I am perplexed that the same posterior samples are used for estimation and prediction, yet you only encounter an error with prediction. Maybe the first step is to verify whether the drift rate goes negative during sampling and prediction. Can you capture the parameters associated with a negative drift rate during estimation and prediction? You can use something like:

if drift < 0
    println("drift_intercept $drift_intercept drift_isi1 $drift_isi1 drift_isi2 $drift_isi2 isi[i, 1] $(isi[i, 1]) isi[i, 2] $(isi[i, 2])")
end

I can't remember if you need to modify the code above for Dual types. If you do and you get stuck, let me know and I'll try to help debug.

itsdfish commented 3 months ago

To follow up, I think you might need to load ForwardDiff to expose Dual and do something like this:

if isa(drift, Dual) && (drift < 0)

and use drift_intercept.value to access the parameter value. The code above might work, but I can't remember.
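
In case it helps, here is a minimal (untested) sketch of that check using ForwardDiff.value, which extracts the primal value from a Dual and passes plain floats through unchanged, so the same branch works during both plain sampling and gradient evaluation:

using ForwardDiff: value

# value(x) returns the primal of a Dual and plain floats unchanged
if value(drift) < 0
    println("drift_intercept $(value(drift_intercept)) drift_isi1 $(value(drift_isi1)) drift_isi2 $(value(drift_isi2)) isi[i, 1] $(isi[i, 1]) isi[i, 2] $(isi[i, 2])")
end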

DominiqueMakowski commented 3 months ago

Actually, today while making an MWE the model failed at sampling (so yesterday's success was probably just a statistical fluke).

Here's the code:

using CSV
using DataFrames
using Turing
using SequentialSamplingModels
using StatsModels
using StatsPlots
using CairoMakie
using Downloads
using JLD2
using Random
Random.seed!(12); # Set seed for reproducibility

include(Downloads.download("https://raw.githubusercontent.com/RealityBending/scripts/main/data_grid.jl"))
include(Downloads.download("https://raw.githubusercontent.com/RealityBending/scripts/main/data_poly.jl"))

# Data ==========================================================================================

df = CSV.read(Downloads.download("https://raw.githubusercontent.com/RealityBending/DoggoNogo/main/study1/data/data_game.csv"), DataFrame)

# Wald ------------------------------------------------------------------------------------------
@model function model_wald(data; min_rt=minimum(data.rt), isi=nothing)

    # Transform ISI into polynomials
    isi = data_poly(isi, 2; orthogonal=true)

    # Priors for coefficients
    drift_intercept ~ truncated(Normal(5, 2), 0.0, Inf)
    drift_isi1 ~ Normal(0, 1)
    drift_isi2 ~ Normal(0, 1)

    # Priors
    α ~ truncated(Normal(0.5, 0.4), 0.0, Inf)
    τ ~ truncated(Normal(0.2, 0.05), 0.0, min_rt)

    for i in 1:length(data)
        drift = drift_intercept
        drift += drift_isi1 * isi[i, 1]
        drift += drift_isi2 * isi[i, 2]
        data[i] ~ Wald(drift, α, τ)
    end
end

model = model_wald(df.RT, min_rt=minimum(df.RT), isi=df.ISI)
chain_wald = sample(model, NUTS(), 1000)

# Prediction
grid = data_grid(df.ISI)
pred = predict(model_wald([(missing) for i in 1:length(grid)]; min_rt=minimum(df.RT), isi=grid), chain_wald)

What would be the best way to constrain the drift rate to be > 0?

Do you think it would make sense to clamp it like so:

    for i in 1:length(data)
        drift = drift_intercept
        drift += drift_isi1 * isi[i, 1]
        drift += drift_isi2 * isi[i, 2]
        if drift < 0
            drift = 0.0
        end
        data[i] ~ Wald(drift, α, τ)
    end
itsdfish commented 3 months ago

Ok. That makes more sense now.

I think your idea of forcing it to be non-negative is good if you want your slopes to be able to take negative values. You can also use drift = max(0, drift); I'm not sure there is a reason to prefer one over the other. The other option (assuming isi is positive) is to constrain your slopes to be positive:

drift_isi1 ~ truncated(Normal(0, 1), 0.0, Inf)
drift_isi2 ~ truncated(Normal(0, 1), 0.0, Inf)

There are other transformations from ℝ to ℝ⁺, but they are non-linear and would change the functional form and interpretation of your model.
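
To make that concrete, here is an (untested) sketch of the in-loop clamping variant. One caveat: if Wald maps to an InverseGaussian with mean α/ν (as the stack trace suggests), a drift of exactly 0 gives an infinite mean, so a small positive floor may be safer than 0.0:

    for i in 1:length(data)
        drift = drift_intercept + drift_isi1 * isi[i, 1] + drift_isi2 * isi[i, 2]
        drift = max(eps(), drift)  # hard floor just above zero; keeps the linear form elsewhere
        data[i] ~ Wald(drift, α, τ)
    end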

DominiqueMakowski commented 3 months ago

> The other option (assuming isi is positive) is to constrain your slopes

That probably wouldn't work: although the Inter-Stimulus Interval is always positive, its effect (slope) on RT is typically negative (i.e., very short ISIs -> higher RT). Ideally, I would like not to constrain the effects directly too much, but rather to "discard all combinations of parameters that result in negative drift" - basically a constraint on the joint parameter space - but I think that's not really feasible 😅

itsdfish commented 3 months ago

Got it. Perhaps exponential decay would not be unreasonable: drift = beta_0 * exp(-beta_1 * isi_1 - beta_2 * isi_2)
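
In your model loop that could look something like this (untested sketch; since drift_intercept is kept positive by its truncated prior, the drift stays strictly positive for any slope values):

    for i in 1:length(data)
        # exponential-decay link: positive for any slope values, given drift_intercept > 0
        drift = drift_intercept * exp(-drift_isi1 * isi[i, 1] - drift_isi2 * isi[i, 2])
        data[i] ~ Wald(drift, α, τ)
    end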

DominiqueMakowski commented 3 months ago

Interesting idea, thanks! I'll close this as it's more a model-specification issue than a bug.