Open timholy opened 10 months ago
Yes, I noticed that also.
The `robust = true` keyword kind of prevents some of this behavior, but it does not catch everything.
I think in some sense this is really inherent to the EM algorithm: if it starts near a local minimum that has a dropout component, it will head toward it until numerical precision produces an error. I don't think there is much we can do, aside from implementing a different version of EM that escapes these holes.
That said, maybe something like LogarithmicNumbers.jl or, for the exponential family, ExponentialFamily.jl could help?
In practice, I also added this `fit_mle` to test over multiple initial conditions, return the best fitted model, and avoid errors with `try` and `catch`.
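For reference, a rough sketch of what such a multi-start wrapper could look like (the name `fit_mle_robust` and the `make_init` callback are mine for illustration, not part of the package):

```julia
using Distributions, ExpectationMaximization

# Hypothetical helper (a sketch, not the implementation referenced above):
# try several random initializations, skip the ones that throw, keep the best fit.
function fit_mle_robust(make_init, X; ntries = 10)
    best, best_ll = nothing, -Inf
    for _ in 1:ntries
        mix0 = make_init(X)             # user-supplied random initialization
        try
            mix = fit_mle(mix0, X)
            ll = loglikelihood(mix, X)
            if ll > best_ll
                best, best_ll = mix, ll
            end
        catch err
            isa(err, InterruptException) && rethrow(err)
            # DomainError, PosDefException, etc.: discard this initialization
        end
    end
    best === nothing && error("all $ntries initializations failed")
    return best
end
```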
If returning "empty" components is OK, one easy option might be simply to add `N*α[i] < thresh && continue` so that components assigned fewer than `thresh` points just don't get updated. One could make `thresh = 1` perhaps by default, but there would also be arguments for either `thresh = 1e-6` or `thresh = d^2/2 + d + 1` (the latter basically saying we want enough data to determine the amplitude, mean, and covariance matrix).
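For illustration, here is roughly where such a guard could sit in a generic Gaussian M-step (a sketch of the idea using my own assumed names, not the package's actual internals):

```julia
using Distributions, LinearAlgebra, Statistics

# Sketch only: γ is an N×K responsibility matrix from the E-step, X is d×N data,
# dists and α hold the current components and weights.
function mstep!(dists, α, γ, X; thresh = 1)
    N = size(γ, 1)
    α[:] = vec(mean(γ, dims = 1))
    for i in eachindex(dists)
        N * α[i] < thresh && continue    # too few points: leave component i untouched
        w = γ[:, i] ./ sum(γ[:, i])      # normalized weights for component i
        μ = X * w                        # weighted mean
        Xc = X .- μ
        Σ = (Xc .* w') * Xc'             # weighted covariance
        dists[i] = MvNormal(μ, Symmetric(Σ))
    end
    return dists, α
end
```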
To get a sense of how common this is, I wrote a quick script to generate random test cases and then report back cases that exhibited various classes of errors:
using ExpectationMaximization
using Distributions
using Random

nwanted = 3     # number of failing cases to collect per error class
nmax = 10000    # maximum number of random trials

# For DomainError
domerrX = Matrix{Float64}[]
domerridxs = Vector{Int}[]   # indices of the centers in corresponding X
# For posdef errors
pderrX = Matrix{Float64}[]
pderridxs = Vector{Int}[]

# Build an initial mixture whose component means are data points and whose
# covariances are isotropic with unit variance.
function init_mixture(X, centeridxs)
    dist = [MvNormal(X[:, idx], 1) for idx in centeridxs]
    αs = ones(length(centeridxs)) / length(centeridxs)
    return MixtureModel(dist, αs)
end

for i = 1:nmax
    (length(domerrX) >= nwanted && length(pderrX) >= nwanted) && (@show i; break)
    ctrue = [randn(2) for _ = 1:3]                                # three true 2-d centers
    X = reduce(hcat, [randn(length(c), 20) .+ c for c in ctrue])  # 20 points per center
    X = round.(X; digits=2)   # to make it easy to write to a text file
    startidx = randperm(60)[1:3]
    mix = init_mixture(X, startidx)
    try
        fit_mle(mix, X)
    catch err
        isa(err, InterruptException) && rethrow(err)
        if isa(err, DomainError)
            if length(domerrX) < nwanted
                push!(domerrX, X)
                push!(domerridxs, startidx)
            end
        else
            if length(pderrX) < nwanted
                push!(pderrX, X)
                push!(pderridxs, startidx)
            end
        end
    end
end
This didn't generate any of the positive-definite errors I've seen in other circumstances (maybe that requires higher dimensionality?), but somewhere between 5% and 10% of all cases resulted in a dropout. There doesn't appear to be anything particularly bizarre about them; here's a typical case:
The red dots are both data points and the starting positions of the clusters. If there's a pattern, it seems that at least one of the red dots should be fairly near the cluster edge.
So what ends up happening is that `Σ → 0` because only a single point gets associated with a component. The existing `robust=true` fails to catch this because it produces `NaN` rather than `Inf`: `exp(-mahalanobis^2)/sqrt(det(Σ)) → 0/0`. It's likely that some kind of shrinkage might be the best solution, but I pushed a bandaid in #12.
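To make the failure mode concrete, here is a small standalone illustration (not code from the package) of how a collapsed covariance turns the Gaussian density into `0/0`:

```julia
using LinearAlgebra

# Illustration only: a component whose covariance has collapsed toward zero.
Σ = 1e-300 * Matrix{Float64}(I, 2, 2)   # nearly-zero covariance
μ = [0.0, 0.0]
x = [10.0, 10.0]                        # a point far from the collapsed component

maha2 = (x - μ)' * inv(Σ) * (x - μ)     # enormous squared Mahalanobis distance
num = exp(-maha2 / 2)                   # underflows to 0.0
den = sqrt(det(Σ))                      # det(Σ) underflows to 0.0 as well
num / den                               # 0.0 / 0.0 == NaN, not Inf, so it slips past robust = true
```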
In cases of poor initialization, some components of the mixture may drop out. For example, let's create a 2-component mixture that is very poorly initialized:
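The original snippet isn't reproduced here, but a minimal sketch of that kind of setup (my own illustrative numbers, not the exact ones from the issue) might look like:

```julia
using Distributions

# Illustrative setup: data from two well-separated clusters, with both starting
# components placed far from the data and one of them much farther than the other.
x = vcat(randn(500) .- 5, randn(500) .+ 5)
mix0 = MixtureModel([Normal(20.0, 1.0), Normal(100.0, 1.0)], [0.5, 0.5])

# Log-likelihood of the data under each starting component.
ll = [loglikelihood(c, x) for c in components(mix0)]
```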
You can see that both have poor likelihood, but one of the two always loses by a very large margin. Then when we go to optimize:
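The original output isn't shown here; continuing the illustrative sketch above, this is roughly what the first E-step computes and why the weight update then collapses:

```julia
using Statistics

# Responsibilities γ[n, k] ∝ α[k] * pdf(component k, x[n]), normalized per row.
α0 = probs(mix0)
comp = components(mix0)
γ = [α0[k] * pdf(comp[k], x[n]) for n in eachindex(x), k in eachindex(comp)]
γ ./= sum(γ, dims = 2)

vec(mean(γ, dims = 1))   # ≈ [1.0, 0.0]: all the weight goes to component 1
```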
This arises because `α[:] = mean(γ, dims = 1)` returns `α = [1.0, 0.0]`. In other words, component 2 of the mixture "drops out."

I've found errors like these, as well as positive-definiteness errors in a multivariate context, to be pretty ubiquitous when fitting complicated distributions and point clouds. To me it seems we'd need to come up with some kind of guard against this behavior, but I'm not sure what the state-of-the-art approach is, or I'd implement it.