dmetivie / ExpectationMaximization.jl

A simple but generic implementation of Expectation Maximization algorithms to fit mixture models.
https://dmetivie.github.io/ExpectationMaximization.jl/
MIT License

Handling dropouts #11

Open timholy opened 10 months ago

timholy commented 10 months ago

In cases of poor initialization, some components of the mixture may drop out. For example, let's create a 2-component mixture that is very poorly initialized:

julia> X = randn(10);

julia> mix = MixtureModel([Normal(100, 0.001), Normal(200, 0.001)], [0.5, 0.5]);

julia> logpdf.(components(mix), X')
2×10 Matrix{Float64}:
 -4.92479e9   -4.97741e9   -5.02964e9   -5.15501e9   -5.05792e9   …  -5.16391e9   -4.88617e9   -4.93348e9   -5.09162e9
 -1.98493e10  -1.99548e10  -2.00592e10  -2.03088e10  -2.01157e10     -2.03265e10  -1.97717e10  -1.98667e10  -2.01828e10

You can see that both have poor likelihood, but one of the two always loses by a very large margin. Then when we go to optimize,

julia> fit_mle(mix, X)
ERROR: DomainError with NaN:
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/dev/Distributions/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:36 [inlined]
  [5] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:229 [inlined]
  [6] fit_mle(::Type{Normal{Float64}}, x::Vector{Float64}, w::Vector{Float64}; mu::Float64, sigma::Float64)
    @ Distributions ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:256
  [7] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:253 [inlined]
  [8] fit_mle
    @ ~/.julia/dev/ExpectationMaximization/src/that_should_be_in_Distributions.jl:17 [inlined]
  [9] (::ExpectationMaximization.var"#2#3"{Vector{Normal{Float64}}, Vector{Float64}, Matrix{Float64}})(k::Int64)
    @ ExpectationMaximization ./none:0
 [10] iterate(::Base.Generator{Vector{Any}, DualNumbers.var"#1#3"})
    @ Base ./generator.jl:47 [inlined]
 [11] collect_to!(dest::AbstractArray{T}, itr::Any, offs::Any, st::Any) where T
    @ Base ./array.jl:890 [inlined]
 [12] collect_to_with_first!(dest::AbstractArray, v1::Any, itr::Any, st::Any)
    @ Base ./array.jl:868 [inlined]
 [13] collect(itr::Base.Generator{UnitRange{Int64}, ExpectationMaximization.var"#2#3"{Vector{…}, Vector{…}, Matrix{…}}})
    @ Base ./array.jl:842
 [14] fit_mle!(α::Vector{…}, dists::Vector{…}, y::Vector{…}, method::ClassicEM; display::Symbol, maxiter::Int64, atol::Float64, robust::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:48
 [15] fit_mle!
    @ ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:14 [inlined]
 [16] fit_mle(::MixtureModel{…}, ::Vector{…}; method::ClassicEM, display::Symbol, maxiter::Int64, atol::Float64, robust::Bool, infos::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:30
 [17] fit_mle(::MixtureModel{Univariate, Continuous, Normal{Float64}, Categorical{Float64, Vector{Float64}}}, ::Vector{Float64})
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:12
 [18] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

This arises because α[:] = mean(γ, dims = 1) returns α = [1.0, 0.0]. In other words, component 2 of the mixture "drops out."
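
For concreteness, here is a rough reconstruction of how the E-step responsibilities underflow in this example (my own sketch, not the package's internal code; the package stores γ as observations × components, hence mean(γ, dims = 1) there):

LL = logpdf.(components(mix), X')       # the 2×10 log-likelihood matrix shown above
γ = exp.(LL .- maximum(LL, dims = 1))   # per-observation softmax numerators
γ ./= sum(γ, dims = 1)                  # normalized responsibilities
α = vec(mean(γ, dims = 2))              # [1.0, 0.0]: with log-likelihood gaps of ~1.5e10,
                                        # exp underflows and component 2 gets exactly zero weight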

I've found errors like these, as well as positive-definiteness errors in the multivariate case, to be pretty ubiquitous when fitting complicated distributions to point clouds. It seems we'd need to come up with some kind of guard against this behavior, but I'm not sure what the state-of-the-art approach is, or I'd implement it myself.

dmetivie commented 10 months ago

Yes, I noticed that too. The robust = true keyword prevents some of these behaviors, but it is far from catching everything.

I think this is in some sense inherent to the EM algorithm: if it starts near a local minimum that has a dropped-out component, it will head toward it until numerical precision produces an error. I don't think there is much we can do, aside from implementing a variant of EM that escapes these holes.

That said, maybe something like LogarithmicNumbers.jl, or ExponentialFamily.jl for the exponential family, could help?

In practice, I also added this fit_mle to test over multiple initial conditions and return the best fitted model, avoiding errors with try and catch (a rough sketch of the idea is below).
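
Something like this (a rough sketch of the idea, with a hypothetical helper name; not the actual function added to the package):

using Distributions, ExpectationMaximization

function fit_mle_multistart(make_mix, X; ntries = 10)
    best, bestll = nothing, -Inf
    for _ in 1:ntries
        try
            fitted = fit_mle(make_mix(), X)   # make_mix() returns a fresh random initialization
            ll = loglikelihood(fitted, X)
            if ll > bestll
                best, bestll = fitted, ll
            end
        catch err
            isa(err, InterruptException) && rethrow(err)   # never swallow Ctrl-C
        end
    end
    return best   # `nothing` if every initialization errored
end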

timholy commented 10 months ago

If returning "empty" components is OK, one easy option might simply be to add N*α[i] < thresh && continue so that components assigned fewer than thresh points just don't get updated (see the sketch below). One could perhaps make thresh = 1 the default, but there are also arguments for either thresh = 1e-6 or thresh = d^2/2 + d + 1 (the latter basically saying we want enough data to determine the amplitude, mean, and covariance matrix).
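
For illustration, inside the M-step loop the guard could look roughly like this (a sketch assuming dists, α, responsibilities γ, data y, and sample count N as in fit_mle!; not actual package code):

thresh = 1   # or 1e-6, or d^2/2 + d + 1 for a d-dimensional Gaussian
for k in eachindex(dists)
    N * α[k] < thresh && continue   # effectively empty component: leave it untouched
    dists[k] = fit_mle(typeof(dists[k]), y, γ[:, k])   # usual weighted MLE update
end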

timholy commented 10 months ago

To get a sense of how common this is, I wrote a quick script to generate random test cases and then report back cases that exhibited various classes of errors:

using ExpectationMaximization
using Distributions
using LinearAlgebra   # for the identity `I` used in MvNormal below
using Random

nwanted = 3
nmax = 10000

# For DomainError
domerrX = Matrix{Float64}[]
domerridxs = Vector{Int}[]   # indices of the centers in corresponding X

# For posdef errors
pderrX = Matrix{Float64}[]
pderridxs = Vector{Int}[]

function init_mixture(X, centeridxs)
    # one unit-covariance Gaussian centered on each selected data point
    # (MvNormal(μ, σ::Real) is deprecated; MvNormal(μ, I) is the equivalent)
    dist = [MvNormal(X[:, idx], I) for idx in centeridxs]
    αs = ones(length(centeridxs)) / length(centeridxs)   # uniform mixture weights
    return MixtureModel(dist, αs)
end

for i = 1:nmax
    (length(domerrX) >= nwanted && length(pderrX) >= nwanted) && (@show i; break)
    ctrue = [randn(2) for _ = 1:3]
    X = reduce(hcat, [randn(length(c), 20) .+ c for c in ctrue])
    X = round.(X; digits=2)    # to make it easy to write to a text file
    startidx = randperm(60)[1:3]
    mix = init_mixture(X, startidx)
    try
        fit_mle(mix, X)
    catch err
        isa(err, InterruptException) && rethrow(err)
        if isa(err, DomainError)
            if length(domerrX) < nwanted
                push!(domerrX, X)
                push!(domerridxs, startidx)
            end
        else
            if length(pderrX) < nwanted
                push!(pderrX, X)
                push!(pderridxs, startidx)
            end
        end
    end
end

This didn't generate any of the positive-definiteness errors I've seen in other circumstances (maybe those require higher dimensionality?), but somewhere between 5% and 10% of all cases resulted in a dropout. There doesn't appear to be anything particularly bizarre about them; here's a typical case:

[Image: scatter plot of a typical failing case, showing the sampled point cloud with the three initialization points highlighted in red.]

The red dots are both data points and the starting positions of the clusters. If there's a pattern, it's that at least one of the red dots tends to lie fairly near the edge of a cluster.

timholy commented 10 months ago

So, what ends up happening is that Σ → 0 because only a single point gets associated with a component. The existing robust = true fails to catch this because the density evaluates to NaN rather than Inf: exp(-mahalanobis^2/2)/sqrt(det(Σ)) → 0/0. It's likely that some kind of shrinkage would be the best solution, but I pushed a bandaid in #12.
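
For reference, shrinkage here could be as simple as blending each weighted covariance estimate with the identity (λ and the identity target are illustrative assumptions on my part, unrelated to what #12 actually does):

using LinearAlgebra

# Hypothetical post-M-step regularization: keeps Σ positive definite even when a
# component collapses onto a single point (Σ → 0). λ is an arbitrary small constant.
shrink(Σ::AbstractMatrix, λ::Real = 1e-6) = Symmetric((1 - λ) * Σ + λ * I)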