JuliaPlots / StatsPlots.jl

Statistical plotting recipes for Plots.jl

Shouldn't the PDFs of `UnivariateGMM` (and possibly other mixtures) be "squashed" according to their weights? #458

Open ForceBru opened 3 years ago

ForceBru commented 3 years ago

TL;DR: the PDFs of the components currently seem to be plotted without accounting for their weights (which the code calls `prior`). It might be better to plot them weighted, or to plot the PDF of the mixture itself (instead of the PDFs of its components), with an option to add the weighted component PDFs to the plot.


The pull request https://github.com/JuliaPlots/StatsPlots.jl/pull/456 introduced plotting of UnivariateGMM, so I was trying to use it to see how well my model describes the data. However, the fit looked terrible (the PDFs of the components were way too spiky) despite convergence criteria and the log-likelihood showing that the fit wasn't that bad. I plotted the data from the pull request's sample code and found that the PDFs of the components are not weighted:

using StatsPlots
using Distributions

distr = UnivariateGMM(
    [0, 4], [1, 2], Categorical([0.2, 0.8])
)

data = rand(distr, 10_000)

plt = histogram(data, bins=300, normalize=true, linewidth=0)
plot!(plt, distr)
savefig(plt, "stats.png")

[Image: stats.png, the histogram with the unweighted component PDFs overlaid]

I think the PDF on the left doesn't look anything like the leftmost "bump" of the histogram, so it looks like the mixture model fits the data poorly, yet we know that the data were literally sampled from this exact mixture. I would expect the plots of the two PDFs to sort of "hug" the histogram, like this:

x = range(minimum(data), maximum(data), length=500)
μs, σs, p_distr = params(distr)
plt2 = histogram(data, bins=300, normalize=true, linewidth=0)
for (μ, σ, p) ∈ zip(μs, σs, probs(p_distr))
    # Weight the PDF by multiplying by `p`
    the_pdf = @. p * exp(-1/2 * (x - μ)^2/σ^2) / (sqrt(2π) * σ)
    plot!(plt2, x, the_pdf, label="μ = $μ", linewidth=3)
end
savefig(plt2, "stats2.png")

[Image: stats2.png, the histogram with the weighted component PDFs overlaid]

Of course, there's overlap in the middle that individual components can't explain that well, but the full PDF of the mixture can:

x = range(minimum(data), maximum(data), length=500)
full_pdf = zeros(size(x))
μs, σs, p_distr = params(distr)
for (μ, σ, p) ∈ zip(μs, σs, probs(p_distr))
    @. full_pdf += p * exp(-1/2 * (x - μ)^2/σ^2) / (sqrt(2π) * σ)
end
plt3 = histogram(data, bins=300, normalize=true, linewidth=0)
plot!(plt3, x, full_pdf, label="Mixture PDF", linewidth=3)
savefig(plt3, "stats3.png")

[Image: stats3.png, the histogram with the full mixture PDF overlaid]


So, maybe plot the PDF of the mixture and, optionally, the weighted components' PDFs?
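For reference, the same curves can probably be obtained without writing out the Gaussian formula by hand. The following is only a minimal sketch: it assumes the mixture above is rewritten as a generic `MixtureModel` (so that `components` and `probs` are available), and the grid bounds are arbitrary values chosen for illustration.

```julia
using Distributions

# Assumption: the mixture from the example above, expressed as a generic MixtureModel.
mix = MixtureModel([Normal(0.0, 1.0), Normal(4.0, 2.0)], [0.2, 0.8])

# Evaluation grid (bounds picked by hand for this sketch).
x = range(-5, 12, length=500)

# Weighted component PDFs: each curve integrates to its weight, not to 1.
weighted_pdfs = [p .* pdf.(c, x) for (c, p) in zip(components(mix), probs(mix))]

# Full mixture PDF, evaluated by Distributions itself.
mixture_pdf = [pdf(mix, xi) for xi in x]
```

Each entry of `weighted_pdfs` could then be drawn with `plot!(plt, x, curve)`, and `mixture_pdf` is the pointwise sum of those curves.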

sethaxen commented 3 years ago

Personally, I don't like the default of plotting the individual components of the mixture model either, for exactly the reasons you've given. But I think we would need to look into why those decisions were made before changing the default.

sethaxen commented 3 years ago

The recipe for MixtureModel was added in #246. Neither then nor since does there seem to have been any discussion of the default for the `components` argument. I agree the default should be `components=false`. @mkborregaard or @BeastyBlacksmith if we do fix this, would that be considered a breaking change?

BeastyBlacksmith commented 3 years ago

There is room for interpretation here, but I'd say that changing defaults is a breaking change. Since StatsPlots is pre-1.0, though, making these changes is fine. I'd think about other breaking changes you might want to make and do them all in one batch.

sethaxen commented 2 years ago

Revisiting this, I think the current behavior of plotting all components separately with their own styles makes sense if one has a mixture of discrete and continuous distributions (currently not allowed by Distributions but will be in the future, and censored distributions are examples of this). So e.g. one could plot a Censored(Normal(), -1, 1) with lines within the interval and sticks at the bounds. But if one then constructs a mixture whose components themselves contain discrete parts, one can't programmatically detect this.

One thing we could do is grab the default_ranges for all mixture components, augment the ranges from discrete distributions with nextfloat and prevfloat, then interleave all points, remove duplicates, and evaluate the mixture density. This would show discrete atoms as vertical lines and due to recursion would work for mixtures of mixtures.
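A rough sketch of the grid-building step, in plain Julia with no package dependencies. Everything here is hypothetical: `merged_grid` and `toy_density` are illustrative names, not part of StatsPlots or Distributions, and the toy density merely stands in for a mixed discrete/continuous mixture density.

```julia
# Hypothetical helper: merge the plotting ranges of continuous components with
# the atoms of discrete ones, bracketing each atom by its neighbouring floats.
function merged_grid(continuous_ranges, atoms)
    pts = Float64[]
    for r in continuous_ranges
        append!(pts, r)
    end
    for a in atoms
        a = float(a)
        push!(pts, prevfloat(a), a, nextfloat(a))
    end
    # Interleave all points and drop duplicates.
    return unique!(sort!(pts))
end

# Toy stand-in for a density with atoms at -1 and 1 (mass 0.4 each)
# and a flat continuous part of height 0.1 on (-1, 1).
toy_density(x) = abs(x) == 1 ? 0.4 : (abs(x) < 1 ? 0.1 : 0.0)

grid = merged_grid([range(-1, 1, length=5)], [-1.0, 1.0])
ys = toy_density.(grid)
```

Because each atom is surrounded by its immediate floating-point neighbours, a line plot of `ys` against `grid` would jump up and back down within a negligible horizontal distance, which is what would make the atoms appear as vertical lines.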