itsdfish / SequentialSamplingModels.jl

A unified interface for simulating and evaluating sequential sampling models in Julia.
https://itsdfish.github.io/SequentialSamplingModels.jl/dev/
MIT License

Adding Shifted Lognormal #78

Closed DominiqueMakowski closed 4 months ago

DominiqueMakowski commented 4 months ago

I reckon users of this package might look for another commonly used RT-only model: the shifted lognormal distribution.

It seems like most of the implementation pieces are already in Distributions.jl, but I think it'd be useful to add an SSM-friendly wrapper around it (with consistent argument names) for convenience. What do you think?

itsdfish commented 4 months ago

Does the lognormal race model accomplish your goal? The finishing times are lognormally distributed for each accumulator, and the non-decision time serves as the shift constant.
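
Concretely, in Distributions.jl terms, each LNR accumulator finishes at τ plus a lognormally distributed time, so a single accumulator is exactly a shifted lognormal. A minimal sketch (parameter values are arbitrary):

using Distributions

ν, σ, τ = -1.0, 0.8, 0.3
# The finishing time is the non-decision time plus a lognormal draw,
# i.e., it follows a shifted lognormal distribution.
finish_time = τ + rand(LogNormal(ν, σ))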

itsdfish commented 4 months ago

If you want a common SSM1D interface across your models, you can use this:

using Distributions
using SequentialSamplingModels
import Distributions: logpdf 
import Distributions: rand 

# A lognormal RT distribution shifted by the non-decision time τ
struct ShiftedLogNormal{T <: Real} <: SSM1D
    ν::T # mean on the log scale
    σ::T # standard deviation on the log scale
    τ::T # non-decision time (shift)
end

ShiftedLogNormal(ν, σ, τ) = ShiftedLogNormal(promote(ν, σ, τ)...)

ShiftedLogNormal(; ν, σ, τ) = ShiftedLogNormal(ν, σ, τ)

function logpdf(dist::ShiftedLogNormal, rt)
    (; τ, ν, σ) = dist 
    return logpdf(LogNormal(ν, σ), rt - τ)
end

function rand(dist::ShiftedLogNormal, n_trials::Int)
    (; τ, ν, σ) = dist 
    return rand(LogNormal(ν, σ), n_trials) .+ τ
end

model = ShiftedLogNormal(ν = 1, σ = 1, τ = .20)

rts = rand(model, 100)
logpdf.(model, rts)
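
A nice side effect of this approach, assuming SSM1D subtypes ContinuousUnivariateDistribution (which the Turing integration suggests): the generic Distributions.jl fallback derives pdf from the logpdf method above for free.

# pdf falls out of the logpdf definition via the Distributions.jl fallback
pdf.(model, rts) ≈ exp.(logpdf.(model, rts)) # true
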
DominiqueMakowski commented 4 months ago

If you want a common SSM1D interface

Can we add that 🥺 🙏? I can make a PR if you prefer :) It would be very convenient and add visibility for people looking for "shifted lognormal" (who might not be familiar with its nD equivalent, the LNR)

We can refer to this paper

itsdfish commented 4 months ago

Awesome. Thanks for the reference. I will add a new version of SSMs within the hour.

itsdfish commented 4 months ago

The new version should be available soon. Everything seems to be working:

using SequentialSamplingModels
using Turing 

n_samples = 50
rts = rand(ShiftedLogNormal(ν=-1, σ=.8, τ=.3), n_samples)

@model function model(rts; min_rt = minimum(rts))
    ν ~ Normal(-1, 2)
    σ ~ truncated(Normal(.8, 2), 0, Inf)
    τ ~ Uniform(0, min_rt)
    rts ~ ShiftedLogNormal(ν, σ, τ)
    return (;ν, σ, τ)
end

lb = [-1,0,0]
ub = [10, 10, minimum(rts)]

# Generate a MLE estimate.
mle_estimate = maximum_likelihood(model(rts); lb, ub)

# Generate a MAP estimate.
map_estimate = maximum_a_posteriori(model(rts); lb, ub)

# Sample with NUTS 
chain = sample(model(rts), NUTS(), 1000)
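
As a quick sanity check (a sketch, assuming standard MCMCChains indexing), the posterior means should land near the generating values:

# Should be near ν = -1, σ = .8, τ = .3
mean(chain[:ν]), mean(chain[:σ]), mean(chain[:τ])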
DominiqueMakowski commented 4 months ago

We could potentially add sections / groups for the list of models in the docs

[screenshot: the current list of models in the docs]

To separate "RT-only" models (1-dimensional) from "RT + Errors" models, and potentially "RT + Errors + other stuff" (like in the attention models?)

itsdfish commented 4 months ago

I like that idea. What do you think of this variation:

1D

2D

ND (e.g., Circular Diffusion, SSMs with confidence ratings, and generalized geometric models, e.g., https://peterkvam.com/papers/kvam2019geometric.pdf)

DominiqueMakowski commented 4 months ago

(Just to note that ShiftedLogNormal is currently not in the documentation, I think)

DominiqueMakowski commented 4 months ago

Found an issue: predict() fails on shifted lognormal models (iterate method not defined?)

Here's an MWE with a (working) Wald model followed by the failing ShiftedLogNormal:


using Downloads, CSV, DataFrames
using Turing, Distributions, SequentialSamplingModels
using CairoMakie

df = CSV.read(Downloads.download("https://raw.githubusercontent.com/DominiqueMakowski/CognitiveModels/main/data/wagenmakers2008.csv"), DataFrame)
df = df[df.Error.==0, :]
df.Accuracy = df.Condition .== "Accuracy"

@model function model_wald(rt; min_rt=minimum(df.RT), condition=nothing)

    # Priors 
    σ ~ truncated(Normal(0, 0.5); lower=0)
    τ ~ truncated(Gamma(1.1, 11); upper=min_rt)

    intercept ~ truncated(Normal(0, 2); lower=0)
    slope_accuracy ~ Normal(0, 0.1)

    for i in 1:length(rt)
        μ = intercept + slope_accuracy * condition[i]
        rt[i] ~ Wald(μ, σ, τ)
    end
end

model = model_wald(df.RT; condition=df.Accuracy)
chain_wald = sample(model, NUTS(), 400)

pred = predict(model_wald([(missing) for i in 1:length(df.RT)]; condition=df.Accuracy), chain_wald)

@model function model_lognormal(rt; min_rt=minimum(df.RT), condition=nothing)

    # Priors 
    σ ~ truncated(Normal(0, 0.5); lower=0)
    τ ~ truncated(Gamma(1.1, 11); upper=min_rt)

    intercept ~ Normal(0, 2)
    slope_accuracy ~ Normal(0, 0.5)

    for i in 1:length(rt)
        μ = intercept + slope_accuracy * condition[i]
        rt[i] ~ ShiftedLogNormal(μ, σ, τ)
    end
end

model = model_lognormal(df.RT; condition=df.Accuracy)
chain_lognormal = sample(model, NUTS(), 400)

pred = predict(model_lognormal([(missing) for i in 1:length(df.RT)]; condition=df.Accuracy), chain_lognormal)
itsdfish commented 4 months ago

By the way, your Wald model crashed due to ν being out of range. You might try adding the following above your sampling statement:

  if condition here
      Turing.@addlogprob! -Inf
      # Exit the model evaluation early
      return
  end

I'm not quite sure why the ShiftedLogNormal is not working. I'll report back.

itsdfish commented 4 months ago

The error message was completely useless. The culprit was a missing rand method for a single sample, which should return a Real. A patch version will be released soon.
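
For anyone following along, the missing piece was presumably a scalar rand method along these lines (a sketch of the idea, not necessarily the exact patch):

using Random
import Distributions: rand

# predict() draws one value at a time, so a single-sample rand returning
# a Real is needed in addition to the n_trials method above.
function rand(rng::AbstractRNG, dist::ShiftedLogNormal)
    (; ν, σ, τ) = dist
    return rand(rng, LogNormal(ν, σ)) + τ
end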

DominiqueMakowski commented 4 months ago

By the way, your Wald model crashed due to ν being out of range. You might try adding the following above your sampling statement

What does "if condition here" mean?

itsdfish commented 4 months ago

In that specific case, I think you need if μ < 0.

itsdfish commented 4 months ago

It will record -Inf as the log density and exit the model before you reach an error. It might be better than forcing μ to be 0, because the sampler will spend less time in the region below zero. If the log density is low near zero, it will not matter much.

DominiqueMakowski commented 4 months ago

Hmm, could you put that bit of code in the context of a model? I'm not sure where it should be added.

DominiqueMakowski commented 4 months ago

A patch version will be released soon.

Not sure if 0.11.4 was supposed to fix it, but predict() still fails with the same error, I think:

julia> pred = predict(model_lognormal([(missing) for i in 1:length(df.RT)]; condition=df.Accuracy), chain_lognormal)
       pred = Array(pred)
ERROR: MethodError: no method matching iterate(::ShiftedLogNormal{Float64})

Closest candidates are:
  iterate(::StatsBase.CoefTable)
   @ StatsBase C:\Users\domma\.julia\packages\StatsBase\ebrT3\src\statmodels.jl:42
  iterate(::StatsBase.CoefTable, ::Integer)
   @ StatsBase C:\Users\domma\.julia\packages\StatsBase\ebrT3\src\statmodels.jl:42
  iterate(::Base.AsyncGenerator, ::Base.AsyncGeneratorState)
   @ Base asyncmap.jl:362
  ...
itsdfish commented 4 months ago

Hmmm. Can you restart your Julia session and retry? I restarted and it continued to work on my system. Here are the dependencies in my environment:

(turing_predict) pkg> st
Status `~/.julia/dev/sandbox/turing_predict/Project.toml`
  [336ed68f] CSV v0.10.14
  [a93c6f00] DataFrames v1.6.1
  [f6369f11] ForwardDiff v0.10.36
  [0e71a2a6] SequentialSamplingModels v0.11.4
  [2913bbd2] StatsBase v0.34.3
  [fce5fe82] Turing v0.33.1
  [f43a241f] Downloads v1.6.0
DominiqueMakowski commented 4 months ago

False alarm, all good, it works now :) thanks again!

Just one more thing: the docs say that sigma is on the log scale, but is it really? If it were on the log scale, it would take negative values as well, right (and to recover sigma in un-logged space we would apply an exp() transform)? Or am I missing something?

Finally, what do you think of exp-transforming sigma like this to avoid non-positive values and failures due to negative slopes?

@model function model_lognormal2(rt; min_rt=minimum(df.RT), condition=nothing)

    # Priors 
    τ ~ truncated(Gamma(1.1, 11); upper=min_rt)

    μ_intercept ~ Normal(0, 2)
    μ_condition ~ Normal(0, 0.5)

    σ_intercept ~ -Weibull(2.5, 3) + 1
    σ_condition ~ Normal(0, 0.01)

    for i in 1:length(rt)
        μ = μ_intercept + μ_condition * condition[i]
        σ = σ_intercept + σ_condition * condition[i]
        rt[i] ~ ShiftedLogNormal(μ, exp(σ), τ)
    end
end

It seems to work; the only remaining question is how to correctly specify the prior on sigma's intercept. I wanted a distribution that slowly rises from -Inf and then sharply decreases to 0 at ~1, and came up with this "reverse" Weibull, which covers a plausible range of sigma values (I think?):

xaxis = range(-6, 2, 1000)
fig = Figure()
ax = Axis(fig[1, 1])
# lines!(xaxis, pdf.(-Weibull(2, 2.5)+1, xaxis); color="red")
# lines!(xaxis, pdf.(-Weibull(2.5, 2)+1, xaxis); color="orange")
# lines!(xaxis, pdf.(-Weibull(2.5, 2.5)+1, xaxis); color="blue")
lines!(exp.(xaxis), pdf.(-Weibull(2.5, 3)+1, xaxis); color="green")
fig

[figure: pdf of -Weibull(2.5, 3) + 1 plotted against exp.(xaxis)]
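
(A quick numeric check of the σ range this prior implies, using the same broadcasting as the pdf call above; just a sketch:)

# Median and 95% interval of the implied σ after the exp() transform
exp.(quantile.(-Weibull(2.5, 3) + 1, [0.025, 0.5, 0.975]))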

itsdfish commented 4 months ago

No problem. I'm glad it was a false alarm.

My understanding is that log space (rather than log scale) is the co-domain of a random variable after a logarithmic transformation:

X ∼ LogNormal(μ, σ)
Y = log(X)
mean(Y) = μ
std(Y) = σ

Here is a numerical example:

julia> x = log.(rand(LogNormal(-1, .5), 100_000)); mean(x), std(x)
(-1.0016541123840248, 0.49895896017964375)

I'm not a mathematician, but I think that is correct. We could always replace "log space" with "after a logarithmic transformation".

Sorry, I overlooked your question about @addlogprob!. I think that is the recommended approach. The advantage is that your prior and posterior distributions retain their original interpretations. It should go in your for loop, right after you specify μ:

μ = intercept + slope_accuracy * condition[i]
if μ < 0 
    Turing.@addlogprob! -Inf
    return nothing
end

This will make the log likelihood -Inf, pushing samples out of that region and allowing the model to exit early. I'm not positive whether exponentiating sigma will work or not. I recommend asking the more mathematically savvy Turing devs if the @addlogprob! trick above doesn't work.
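
To put the guard fully in the context of a model (picking up the earlier Wald MWE; a sketch, not tested code):

@model function model_wald(rt; min_rt=minimum(df.RT), condition=nothing)

    # Priors 
    σ ~ truncated(Normal(0, 0.5); lower=0)
    τ ~ truncated(Gamma(1.1, 11); upper=min_rt)

    intercept ~ truncated(Normal(0, 2); lower=0)
    slope_accuracy ~ Normal(0, 0.1)

    for i in 1:length(rt)
        μ = intercept + slope_accuracy * condition[i]
        # Reject this proposal before the sampling statement can error
        if μ < 0
            Turing.@addlogprob! -Inf
            return nothing
        end
        rt[i] ~ Wald(μ, σ, τ)
    end
end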