ElOceanografo / MarginalLogDensities.jl

Marginalized log-probability functions in Julia
MIT License

dimension mismatch in `_marginalize` #9

Closed: slwu89 closed this issue 1 year ago

slwu89 commented 2 years ago

Hi @ElOceanografo, thanks for this very nice (albeit in-progress) package!

I am trying to implement the "urchins" example from Simon Wood's "Core Statistics" to test the package on a somewhat less typical example (see here for the example in R).

I've implemented it as below. Calling the `MarginalLogDensity` object with both the random and fixed effects as input returns the log-likelihood correctly. However, calling it with only the fixed effects, so that it marginalizes over the random effects, fails with a dimension-mismatch error. I tracked it down to this line in the source code: https://github.com/ElOceanografo/MarginalLogDensities.jl/blob/master/src/MarginalLogDensities.jl#L241, but I'm not sure what in my code is causing it. If you are able to take a look, I'd greatly appreciate it.

```julia
using Plots
using MarginalLogDensities
using Distributions
using Random
using Optim

function urchin_V(ω, p, g, a)
    ω = exp(ω)
    p = exp(p)
    g = exp(g)
    am = log(p/(g*ω))/g
    if a < am
        return ω*exp(g*a)
    else
        return p/g + p*(a-am)
    end
end

# number of data
const n = 100

# true parameters

# fixed effects
θ_true = zeros(206)
θ_true[1] = -4.0
θ_true[2] = -0.5
θ_true[3] = log(0.2)
θ_true[4] = 1
θ_true[5] = log(0.1)
θ_true[6] = log(0.1)

# random effects
θ_true[7:106] = rand(Normal(θ_true[2], exp(θ_true[3])), n) #g
θ_true[107:206] = rand(Normal(θ_true[4], exp(θ_true[5])), n) #p

# simulate true data

# urchin ages
a = Float64.(sample(1:30, n))

# urchin volumes
v = urchin_V.(repeat([θ_true[1]],n), θ_true[107:206], θ_true[7:106], a)
v_samp = rand.(Normal.(sqrt.(v), exp(θ_true[6]))) # data + measurement error for likelihood computation

# plot the "data"
scatter(a,v,legend=false)
scatter(a,v_samp,legend=false)

# likelihood function for model with fixed and random effects
function loglik_urchin(θ::Vector{T}) where {T<:Real}

    # parameters
    log_ω = θ[1]
    log_g = θ[7:106]
    log_p = θ[107:206]
    μ_g = θ[2]
    log_σ_g = θ[3]
    μ_p = θ[4]
    log_σ_p = θ[5]
    log_σ = θ[6]

    # estimated volumes of urchins conditional on current θ
    v_est = urchin_V.(repeat([log_ω],n), log_p, log_g, a)

    # data likelihood
    data_lik = sum(logpdf.(Normal.(sqrt.(v_est), exp(log_σ)), v_samp))

    # random effect (growth rates) likelihood
    g_lik = loglikelihood(Normal(μ_g, exp(log_σ_g)), log_g)
    p_lik = loglikelihood(Normal(μ_p, exp(log_σ_p)), log_p)

    return data_lik + g_lik + p_lik
end

loglik_urchin(θ_true)

n_θ = length(θ_true)
marginal_ix = collect(7:206)
marginalloglik_urchin = MarginalLogDensity(loglik_urchin, n_θ, marginal_ix)

# check it works at the true values
marginalloglik_urchin(θ_true[7:end], θ_true[1:6])

# just fixed effects
marginalloglik_urchin(θ_true[1:6])
```
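For reference, the fixed-effects-only call asks for the log-likelihood with the 200 random effects (indices 7:206) integrated out. Assuming `LaplaceApprox` does what its name suggests, it evaluates the standard Laplace approximation below, where $\ell(u, \theta)$ is `loglik_urchin`, $\theta$ is the six fixed effects, $u$ is the random effects, and $m = 200$:

$$
\log \int \exp\{\ell(u, \theta)\}\, du \;\approx\; \ell(\hat u, \theta) + \frac{m}{2}\log(2\pi) - \frac{1}{2}\log\det\!\left(-\nabla_u^2\, \ell(\hat u, \theta)\right),
\qquad \hat u = \arg\max_u \ell(u, \theta).
$$

The log-determinant needs the full $m \times m$ Hessian with respect to $u$, so a wrong or empty sparsity pattern for that Hessian would break the marginalized call even though the full log-likelihood evaluates correctly.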
slwu89 commented 2 years ago

After a bit of digging, it seems that the Hessian sparsity pattern isn't being detected correctly. `Hsparsity` is empty:

```julia
julia> findnz(marginalloglik_urchin.hessconfig.Hsparsity)
(Int64[], Int64[], Bool[])
```
ElOceanografo commented 2 years ago

Thanks for this issue, and apologies for the slow response. I'm actually working on a rewrite of this package after moving some of the Hessian sparsity machinery into SparseDiffTools (https://github.com/JuliaDiff/SparseDiffTools.jl/pull/190). Are you trying to use this currently (i.e., would a bugfix now be helpful)?

slwu89 commented 2 years ago

Hi @ElOceanografo, no problem! I'm glad to see you are still working on it.

Yes, I am trying to use this currently. I've spent some more time thinking about the problem: does the Hessian sparsity detection in this package rely on Symbolics.jl to get the sparsity pattern? I ask because the model code compares two floats (`if a < am` in `urchin_V`), which means Symbolics won't be able to trace it (https://symbolics.juliasymbolics.org/dev/manual/faq/#Transforming-my-function-to-a-symbolic-equation-has-failed.-What-do-I-do?-1). I'm happy to help fix things if you can point me in the right direction! I'm still new to Julia but trying to learn how to use things.
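The `if a < am` branch in `urchin_V` is the kind of value-dependent control flow that tracing-based tools struggle with. Below is a minimal sketch of a branch-free variant, assuming the same parameterization as `urchin_V` above. `urchin_V_ifelse` is a hypothetical name, and we have not checked whether SparsityDetection.jl or Symbolics.jl recover the correct pattern from it; it only removes the Julia-level `if`.

```julia
# Hypothetical branch-free variant of urchin_V (same arguments, all on log scale).
# Both branches are computed and Base.ifelse selects one, so there is no
# Julia-level `if` on a value that depends on the parameters.
function urchin_V_ifelse(ω, p, g, a)
    ω, p, g = exp(ω), exp(p), exp(g)
    am = log(p / (g * ω)) / g        # age at which the two growth phases meet
    early = ω * exp(g * a)           # exponential phase, used when a < am
    late = p / g + p * (a - am)      # linear phase, used when a >= am
    return ifelse(a < am, early, late)
end

urchin_V_ifelse(0.0, 0.0, 0.0, 2.0)   # ω = p = g = 1, so am = 0 and this is the linear phase: 3.0
```

Because `ifelse` evaluates both arguments, both branches need to be finite for every input; for moderate parameter values that should hold here, since ω, p, and g are exponentiated first.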

ElOceanografo commented 2 years ago

The package currently uses SparsityDetection.jl by default, with an option to fall back to ForwardDiff.jl (i.e., calculating the Hessian with ForwardDiff and then looking at which elements are nonzero). ForwardDiff is more robust, but less efficient (or perhaps not feasible at all) if the Hessian is big enough.
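The ForwardDiff fallback described here, computing the dense Hessian and recording which entries are nonzero, can be sketched as follows. `hessian_pattern` and the toy function `chain` are hypothetical names, and this is not MarginalLogDensities' internal code. Evaluating at a random point is our own choice, to reduce the chance that an entry is zero only by coincidence at that point.

```julia
using ForwardDiff, SparseArrays

# Hypothetical helper: dense ForwardDiff Hessian at x, then keep only the
# positions of the nonzero entries as a Bool sparsity pattern.
function hessian_pattern(f, n; x = randn(n))
    H = ForwardDiff.hessian(f, x)
    return sparse(H .!= 0)   # SparseMatrixCSC{Bool}; false entries are not stored
end

# Toy function: each u[i] interacts only with its neighbours, so the
# Hessian is tridiagonal.
chain(u) = sum(abs2, diff(u)) + sum(u .^ 4)

P = hessian_pattern(chain, 5)
findnz(P)   # row indices, column indices, and values of the stored (true) entries
```

The cost is that the full n×n Hessian must be formed first, which is why this approach becomes expensive for large problems.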

I just ran your code, and it looks like SparsityDetection is not able to infer the structure (maybe because of that conditional in `urchin_V`?), but ForwardDiff is. Just change the line constructing `marginalloglik_urchin` to

```julia
marginalloglik_urchin = MarginalLogDensity(loglik_urchin, n_θ, marginal_ix, LaplaceApprox(), true)
```

and check the Hessian pattern. After that, the marginalized call runs, though it produces an infinite log-likelihood, so there's still something going wrong. Will try to look into it more.

slwu89 commented 2 years ago

Thanks @ElOceanografo. I'm trying to rewrite my code to be more AD-friendly as well, I will update soon.

Interesting, and SparsityDetection.jl seems to be deprecated now. Maybe figuring out current best practices/packages to use is something I can help look into.

ElOceanografo commented 2 years ago

Yes, it's been superseded by Symbolics.jl, but I haven't converted this package over yet. In my (limited) experience there are still cases where Symbolics has trouble figuring out the sparsity pattern of Julia functions; if you come across those it would definitely be good to file issues over there. And if you wanted to dig into the current best ways for this package to detect sparsity in arbitrary Julia functions, that would be really helpful!

slwu89 commented 2 years ago

Hi @ElOceanografo, I'll try to look at the sparsity issue next week. In the meantime, I posted a reference implementation using ForwardDiff, with no attempt to exploit sparsity, here: https://gist.github.com/slwu89/136ffe03913883acace3e5378fafce89. It recovers the known correct values from a reference implementation in R.
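A dense reference implementation of this kind can be quite short. The sketch below is not the linked gist; it is a minimal Laplace approximation assuming `ℓ` is the joint log-density viewed as a function of the random effects only, with the fixed effects held constant (for the urchin model, something like `u -> loglik_urchin(vcat(θ_fixed, u))`, where `θ_fixed` is a hypothetical 6-vector). `laplace_logmarginal` is a hypothetical name; Newton's method locates the mode and no sparsity is exploited.

```julia
using ForwardDiff, LinearAlgebra, Distributions

# Hypothetical helper: dense Laplace approximation of log ∫ exp(ℓ(u)) du,
# where ℓ is the joint log-density as a function of the random effects u.
function laplace_logmarginal(ℓ, u0; iters = 50, tol = 1e-10)
    u = float.(u0)
    for _ in 1:iters
        g = ForwardDiff.gradient(ℓ, u)
        H = ForwardDiff.hessian(ℓ, u)
        step = -(H \ g)                # Newton step toward the mode of ℓ
        u = u + step
        norm(step) < tol && break
    end
    H = ForwardDiff.hessian(ℓ, u)      # dense m×m Hessian at the mode
    m = length(u)
    # -H should be positive definite at a maximum; cholesky throws otherwise
    return ℓ(u) + m / 2 * log(2π) - logdet(cholesky(Symmetric(-H))) / 2
end

# Usage: u ~ N(0, 1) and y | u ~ N(u, 1), so marginally y ~ N(0, √2)
y_obs = 1.3
ℓ_toy(u) = logpdf(Normal(0, 1), u[1]) + logpdf(Normal(u[1], 1), y_obs)
laplace_logmarginal(ℓ_toy, [0.0])      # ≈ logpdf(Normal(0, sqrt(2)), y_obs)
```

For a Gaussian model the Laplace approximation is exact, which makes the one-dimensional example a convenient sanity check before trying it on a nonlinear model.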

I checked the intermediate results within `_marginalize` (using ForwardDiff-based sparsity detection), and something is going wrong in how the sparse Hessian is computed. The matrix below shows the diagonal of `H0` from plain `ForwardDiff.hessian` (left column) next to the diagonal computed via `sparse_hessian` (right column): the first entries agree, but further down the sparse version contains `NaN`s, zeros, and tiny values around 1e-314. I'll try to look into this as I get time.

```julia
julia> hcat(diag(ForwardDiff.hessian(f, θmarginal0)), collect(diag(H0)))
284×2 Matrix{Float64}:
 103.474   103.474
  93.3562   93.3562
  91.8251   91.8251
  91.356    91.356
  93.3216   93.3216
  92.501    92.501
  92.501    92.501
  93.5369   93.5369
  94.3318   94.3318
  94.6116   94.6116
  96.5217   96.5217
  93.7606   93.7606
  94.6567   94.6567
   ⋮       
 175.743     2.48333e-307
 175.201   NaN
 180.32      1.65781e-316
 178.602     2.18272e-314
 178.526     4.24399e-314
 177.182   NaN
 173.743     1.6976e-313
 184.349     4.0e-323
 188.217   NaN
 183.786     0.0
 194.774     0.0
 188.571     4.24399e-314
 203.49      0.0
```
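To pin down where a sparse Hessian disagrees with the dense one, one option is an entry-by-entry comparison against `ForwardDiff.hessian` that also flags non-finite values. `hessian_mismatches` and the toy function `toy` are hypothetical names, and the tolerance is an arbitrary choice; in practice `H` would be the decompressed Hessian under test.

```julia
using ForwardDiff, SparseArrays, LinearAlgebra

# Hypothetical diagnostic: compare a candidate Hessian H of f at x with the
# dense ForwardDiff Hessian; return (i, j) positions that are non-finite or
# differ beyond the tolerance.
function hessian_mismatches(f, x, H; rtol = 1e-8)
    Hd = ForwardDiff.hessian(f, x)
    bad = Tuple{Int,Int}[]
    for j in axes(Hd, 2), i in axes(Hd, 1)
        h = H[i, j]
        if !isfinite(h) || !isapprox(h, Hd[i, j]; rtol = rtol, atol = rtol)
            push!(bad, (i, j))
        end
    end
    return bad
end

toy(x) = sum(abs2, diff(x)) + sum(exp, x)
x0 = randn(4)
Hgood = sparse(ForwardDiff.hessian(toy, x0))
hessian_mismatches(toy, x0, Hgood)        # empty: identical to the dense Hessian
Hbad = copy(Hgood); Hbad[4, 4] = NaN
hessian_mismatches(toy, x0, Hbad)         # [(4, 4)]
```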
ElOceanografo commented 2 years ago

Hm, interesting. That would explain it. I would not be surprised if there are mistakes or inaccuracies in the sparse Hessian decompression code currently implemented here...I'll try to finish up converting to use the newer version in SparseDiffTools (https://github.com/ElOceanografo/MarginalLogDensities.jl/issues/10) in the next few days and see if that fixes the problem.

ElOceanografo commented 1 year ago

I just modified the second half of your script as follows using the "sparsediff" branch, and it runs. Once #12 is merged, I think we can close this one...

```julia
function loglik_urchin(θ::Vector{T}, data) where {T<:Real}

    # parameters
    log_ω = θ[1]
    log_g = θ[7:106]
    log_p = θ[107:206]
    μ_g = θ[2]
    log_σ_g = θ[3]
    μ_p = θ[4]
    log_σ_p = θ[5]
    log_σ = θ[6]

    # estimated volumes of urchins conditional on current θ
    v_est = urchin_V.(repeat([log_ω],n), log_p, log_g, data.a)

    # data likelihood
    data_lik = sum(logpdf.(Normal.(sqrt.(v_est), exp(log_σ)), data.v_samp))

    # random effect (growth rates) likelihood
    g_lik = loglikelihood(Normal(μ_g, exp(log_σ_g)), log_g)
    p_lik = loglikelihood(Normal(μ_p, exp(log_σ_p)), log_p)

    return data_lik + g_lik + p_lik
end

data = (;a, v_samp)
loglik_urchin(θ_true, data)

n_θ = length(θ_true)
marginal_ix = collect(7:206)
marginalloglik_urchin = MarginalLogDensity(loglik_urchin, ones(n_θ), marginal_ix, data,
    LaplaceApprox(adtype=MarginalLogDensities.Optimization.AutoReverseDiff()), 
    hess_autosparse=:forwarddiff)

# just fixed effects
using BenchmarkTools
@btime marginalloglik_urchin($θ_true[1:6], $data)
fit = optimize(marginalloglik_urchin, zeros(6), data, Newton(), 
    Optim.Options(show_trace=true, f_tol=1e-6))
scatter(θ_true[1:6], label="true", xlabel="Parameter", ylabel="Value")
scatter!(fit.minimizer, label="fitted")
```
ElOceanografo commented 1 year ago

Fixed in #12.