joshday / OnlineStats.jl

⚡ Single-pass algorithms for statistics
https://joshday.github.io/OnlineStats.jl/latest/
MIT License

Online Dirichlet process univariate Gaussian mixture model #247

Closed: Red-Portal closed this 2 years ago.

Red-Portal commented 2 years ago

Hi, this PR contributes the online Dirichlet process mixture model algorithm in:

Dahua Lin, "Online Learning of Nonparametric Mixture Models via Sequential Variational Approximation." Advances in Neural Information Processing Systems 26, 2013.

I expect this would be useful for some people, given that OnlineStats doesn't currently have Gaussian mixture models. Finite Gaussian mixtures work fine in the online setting, but their sensitivity to initialization can be a concern. A nice thing about DPMMs is that they are less sensitive to initialization (they are sensitive to the hyperparameters instead, but I think that's better than having to run K-means to choose the initial points). Despite its age, the algorithm is still close to the state of the art as far as I've seen. I've implemented the univariate Gaussian mixture model variant of this algorithm with unknown mean and unknown variance:

G ~ DP(dirichlet_alpha, Normal-Gamma)
(μₖ, τₖ) ~ G
x ~ N(μₖ, 1/sqrt(τₖ))

where the base measure is defined as

τₖ ~ Gamma(comp_alpha, 1/comp_beta)
μₖ ~ N(comp_mu, 1/sqrt(comp_lambda*τₖ)).
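To make the generative model concrete, here is a sketch that draws data from it using a truncated stick-breaking construction of the DP. The function name `sample_dpmm`, the truncation level `K`, and the default hyperparameter values are illustrative assumptions, not part of the PR.

```julia
using Distributions, Random

# Draw n observations from the DP Normal-Gamma mixture using a truncated
# stick-breaking construction (K atoms; the leftover stick goes to the last one).
function sample_dpmm(rng::AbstractRNG, n::Integer; comp_mu=0.0, comp_lambda=1.0,
                     comp_alpha=2.0, comp_beta=1.0, dirichlet_alpha=1.0, K=50)
    # Stick-breaking weights: vₖ ~ Beta(1, α), wₖ = vₖ * ∏_{j<k} (1 - vⱼ)
    v = rand(rng, Beta(1.0, dirichlet_alpha), K)
    w = similar(v)
    rest = 1.0
    for k in 1:K
        w[k] = v[k] * rest
        rest *= 1 - v[k]
    end
    w[end] += rest                                   # absorb the truncated tail
    # Atoms from the Normal-Gamma base measure
    τ = rand(rng, Gamma(comp_alpha, 1 / comp_beta), K)
    μ = [rand(rng, Normal(comp_mu, 1 / sqrt(comp_lambda * τk))) for τk in τ]
    # Observations: pick a component zᵢ, then xᵢ ~ N(μ_zᵢ, 1/sqrt(τ_zᵢ))
    z = rand(rng, Categorical(w), n)
    x = [rand(rng, Normal(μ[k], 1 / sqrt(τ[k]))) for k in z]
    return x, z, w
end

x, z, w = sample_dpmm(MersenneTwister(1), 1000)
```

Smaller values of `dirichlet_alpha` concentrate the stick-breaking weights on fewer atoms, so fewer distinct components tend to show up in `z`.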

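The variational distribution factorizes over components (mean-field across components), and each component's factor is a Normal-Gamma distribution:

q(μₖ | τₖ; mₖ, lₖ) q(τₖ; aₖ, bₖ) = N(μₖ; mₖ, 1/sqrt(lₖ*τₖ)) Gamma(τₖ; aₖ, 1/bₖ).

As in the model above, N(m, s) is parameterized by its standard deviation s and Gamma(a, θ) by its shape a and scale θ (so bₖ is a rate), following the Distributions.jl convention.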
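The per-observation update in Lin (2013) can be sketched roughly as follows. This is a minimal illustration, not the PR's implementation: the struct `NGComp` and the functions `expected_loglik`, `log_prior_predictive`, `absorb!` and `sva_step!` are hypothetical names, the birth rule is an assumed analogue of `comp_birth_thres`, and component pruning (`comp_death_thres`, `n_comp_max`) is left out. Each observation is soft-assigned with ρₖ ∝ wₖ exp(E_q[log N(x; μₖ, 1/sqrt(τₖ))]) for existing components and ρ_new ∝ α × (prior predictive of x), and is then folded into every component by a weighted conjugate update.

```julia
using SpecialFunctions: digamma, loggamma
using Random

# Hypothetical Normal-Gamma factor for one component:
# q(μ | τ) q(τ) = N(μ; m, 1/sqrt(l*τ)) Gamma(τ; a, 1/b), plus its soft count w.
mutable struct NGComp
    w::Float64
    m::Float64
    l::Float64
    a::Float64
    b::Float64
end

# E_q[log N(x; μ, 1/sqrt(τ))], using E[τ] = a/b and E[log τ] = ψ(a) - log(b)
function expected_loglik(c::NGComp, x::Real)
    Eτ, Elogτ = c.a / c.b, digamma(c.a) - log(c.b)
    return 0.5 * (Elogτ - log(2π) - Eτ * (x - c.m)^2 - 1 / c.l)
end

# Log prior predictive of x under the base measure: Student-t with 2a₀ dof,
# location m₀ and squared scale b₀(l₀ + 1)/(l₀a₀)
function log_prior_predictive(x::Real, m₀, l₀, a₀, b₀)
    ν, s² = 2a₀, b₀ * (l₀ + 1) / (l₀ * a₀)
    return loggamma((ν + 1) / 2) - loggamma(ν / 2) - 0.5 * log(ν * π * s²) -
           (ν + 1) / 2 * log1p((x - m₀)^2 / (ν * s²))
end

# Conjugate Normal-Gamma update with one observation carrying weight ρ
function absorb!(c::NGComp, x::Real, ρ::Real)
    c.b += 0.5 * ρ * c.l * (x - c.m)^2 / (c.l + ρ)   # uses the old l and m
    c.m  = (c.l * c.m + ρ * x) / (c.l + ρ)
    c.l += ρ
    c.a += ρ / 2
    c.w += ρ
    return c
end

# One sequential variational step: soft-assign x to the existing components or
# to a new one, then fold x into every component with its responsibility.
function sva_step!(comps::Vector{NGComp}, x::Real; m₀=0.0, l₀=1e-3, a₀=1.1,
                   b₀=1e-3, α=1.0, birth_thres=0.5)
    logρ = [log(c.w) + expected_loglik(c, x) for c in comps]
    push!(logρ, log(α) + log_prior_predictive(x, m₀, l₀, a₀, b₀))
    ρ = exp.(logρ .- maximum(logρ))
    ρ ./= sum(ρ)
    if ρ[end] > birth_thres
        push!(comps, NGComp(0.0, m₀, l₀, a₀, b₀))   # birth, initialized at the prior
    else
        pop!(ρ)                                      # no birth: renormalize over K
        ρ ./= sum(ρ)
    end
    for (c, r) in zip(comps, ρ)
        absorb!(c, x, r)
    end
    return comps
end

# Usage: stream draws from a two-component mixture (illustrative settings)
rng   = MersenneTwister(42)
comps = NGComp[]
for _ in 1:2000
    x = rand(rng) < 0.5 ? -3 + 0.5 * randn(rng) : 3 + 0.5 * randn(rng)
    sva_step!(comps, x; l₀=0.01, a₀=2.0, b₀=1.0)
end
```

Because each step distributes a total responsibility of exactly one, the soft counts `w` sum to the number of observations seen, which makes a cheap sanity check. The sketch assumes `birth_thres < 1`, so the very first observation always spawns a component.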

Here is a minimal working example:

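using OnlineStats
using ProgressMeter
using Plots, StatsPlots
using Distributions

function main()
    n     = 1024
    μ     = 0.0     # prior mean of μₖ (comp_mu)
    λ     = 1e-3    # precision scaling of the prior on μₖ (comp_lambda)
    α     = 1.1     # Gamma shape for τₖ (comp_alpha)
    β     = 1e-3    # Gamma rate for τₖ (comp_beta)
    α_dp  = 1.0     # DP concentration (dirichlet_alpha)
    o     = DPMM(μ, λ, α, β, α_dp; comp_birth_thres=0.5, comp_death_thres=1e-2, n_comp_max=10)
    # Target: a three-component Gaussian mixture
    p     = MixtureModel([Normal(-2.0, 0.5), Normal(3.0, 1.0), Normal(0.0, 0.2)], [0.4, 0.4, 0.2])
    x_dom = -5:0.01:5

    # To save an animation instead of displaying each frame:
    # anim = @animate for i in 1:n
    #     fit!(o, rand(p))
    #     q = value(o)
    #     Plots.plot( x_dom, t -> pdf(q, t), label="Variational")
    #     Plots.plot!(x_dom, t -> pdf(p, t), label="Target")
    # end
    # gif(anim, "dpmm.gif", fps=30)

    @showprogress for i in 1:n
        x = rand(p)        # draw one observation from the target
        fit!(o, x)         # single-pass update of the DPMM
        q = value(o)       # current variational mixture
        plt = Plots.plot(x_dom, t -> pdf(q, t), label="Variational")
        Plots.plot!(plt, x_dom, t -> pdf(p, t), label="Target")
        display(plt)
    end
end

main()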

And here is an animation generated using InteractiveUtils.jl: [animation: dpmm]

The implementation is essentially complete, but I would appreciate some feedback. Here are some potential concerns:

Prior Elicitation for the Hyperparameters

Below is a basic snippet for automatically setting the hyperparameters.

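    using Distributions, Roots

    # Target for the prior on the component variance 1/τₖ
    prob_τₖ = 0.8    # P(1/τₖ ≤ τₖ_max)
    τₖ_max  = 0.5    # upper bound on the component variance 1/τₖ
    μ₀      = 0.0
    α₀      = 2.1
    # Target for the prior on the component mean μₖ
    prob_μₖ = 0.8    # P(μₖ_min ≤ μₖ ≤ μₖ_max)
    μₖ_min  = -2
    μₖ_max  = 2

    # τₖ ~ Gamma(α₀, 1/β₀) implies 1/τₖ ~ InverseGamma(α₀, β₀)
    β₀ = find_zero(β₀ -> cdf(InverseGamma(α₀, β₀), τₖ_max) - prob_τₖ, (1e-4, 1e+2))
    # Marginally, μₖ is Student-t with 2α₀ dof, location μ₀, scale sqrt(β₀/(λ₀α₀))
    λ₀ = find_zero(λ₀ -> begin
        p_μ = TDist(2*α₀)*sqrt(β₀/(λ₀*α₀)) + μ₀
        cdf(p_μ, μₖ_max) - cdf(p_μ, μₖ_min) - prob_μₖ
    end, (1e-4, 1e+2))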

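It picks β₀ so that the inverse-gamma prior on the component variance 1/τₖ puts 80% of its mass below τₖ_max, and then picks λ₀ so that the marginal Student-t prior on μₖ puts 80% of its mass in [μₖ_min, μₖ_max].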
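One way to sanity-check elicited values is to sample from the Normal-Gamma prior and compare the empirical coverage with the closed forms the snippet inverts (the inverse-gamma CDF for 1/τₖ and the Student-t marginal for μₖ). The helper names `prior_coverage`, `closed_form_var` and `closed_form_μ`, and the hyperparameter values, are illustrative assumptions.

```julia
using Distributions, Random, Statistics

# Monte Carlo check of a Normal-Gamma prior: draw (μₖ, τₖ) and measure how much
# mass falls in the target ranges used for elicitation.
function prior_coverage(μ₀, λ₀, α₀, β₀; var_max=0.5, μ_min=-2.0, μ_max=2.0,
                        n=100_000, rng=Random.default_rng())
    τ = rand(rng, Gamma(α₀, 1 / β₀), n)          # τₖ ~ Gamma(α₀, scale 1/β₀)
    μ = μ₀ .+ randn(rng, n) ./ sqrt.(λ₀ .* τ)    # μₖ | τₖ ~ N(μ₀, 1/sqrt(λ₀τₖ))
    return mean(1 ./ τ .<= var_max), mean(μ_min .<= μ .<= μ_max)
end

# Closed forms that the elicitation snippet inverts
closed_form_var(α₀, β₀; var_max=0.5) = cdf(InverseGamma(α₀, β₀), var_max)
function closed_form_μ(μ₀, λ₀, α₀, β₀; μ_min=-2.0, μ_max=2.0)
    σ = sqrt(β₀ / (λ₀ * α₀))                    # scale of the Student-t marginal
    t = TDist(2α₀)
    return cdf(t, (μ_max - μ₀) / σ) - cdf(t, (μ_min - μ₀) / σ)
end

α₀, β₀, λ₀, μ₀ = 2.1, 0.5, 0.1, 0.0               # illustrative values only
mc_var, mc_μ = prior_coverage(μ₀, λ₀, α₀, β₀; rng=MersenneTwister(7))
```

With 100,000 draws the Monte Carlo standard error is well under 0.01, so agreement to two decimal places is a reasonable check of the Student-t marginal used above.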

joshday commented 2 years ago

I haven't looked over this yet, but my first thought is that I have previously made the decision to not allow Distributions.jl to be a dependency because of how heavy it is. Due to that decision, the OnlineStats in src/stats/distributions.jl are a little awkward to use.

I think in a perfect world we would have the src/stats/distributions.jl stuff as well as this PR in its own OnlineDistributions.jl (or similar) package, one that can lean into the Distributions.jl dependency and not have to find ways to work around it.

Red-Portal commented 2 years ago

@joshday Hi, thanks for the quick response. The implementation itself does not rely on Distributions.jl in any meaningful way; it's only used as an interface for returning the learned mixture model, so I think your concern can be addressed. We'll then have to think about how the user is supposed to interact with the mixture model, though.

Since the stats in src/stats/distributions.jl work fine by returning the distribution parameters, DPMM could probably do the same.
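If DPMM only returns parameters, users would evaluate the mixture themselves. A minimal sketch of that, assuming the parameters come back as plain vectors of weights, means and standard deviations (the helper `mixture_pdf` and this layout are hypothetical, not the PR's API):

```julia
# Hypothetical helper: evaluate a univariate Gaussian mixture density from
# plain parameter vectors, without any Distributions.jl types.
function mixture_pdf(w::AbstractVector, m::AbstractVector, s::AbstractVector, x::Real)
    length(w) == length(m) == length(s) ||
        throw(DimensionMismatch("parameter vectors differ in length"))
    total = 0.0
    for k in eachindex(w, m, s)
        z = (x - m[k]) / s[k]                     # standardized distance
        total += w[k] * exp(-z^2 / 2) / (s[k] * sqrt(2π))
    end
    return total
end

mixture_pdf([0.5, 0.5], [-1.0, 1.0], [1.0, 1.0], 1.0)
```

For a Normal-Gamma factor, one plug-in choice for the standard deviation is sqrt(bₖ/aₖ), i.e. 1/sqrt(E[τₖ]); the exact posterior predictive would instead be a Student-t.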

Red-Portal commented 2 years ago

Just curious, though: while I do agree that Distributions.jl is a big, big package, is there a practical reason to avoid it? Splitting into OnlineDistributions.jl could be a solution, but I don't see why OnlineStats.jl shouldn't grow into a big package itself.

joshday commented 2 years ago

while I do agree that Distributions.jl is a big, big package, is there a practical reason to avoid it?

Off the top of my head:

  1. It slows down testing and CI.

  2. Increases the likelihood of an upper-bounded compat entry somewhere in the dependency stack.

    2a. I've had a lot of headaches because of upper bounds in packages with lots of dependencies. My past experience certainly influences my preference for avoiding hard dependencies unless they're necessary.

    2b. I suspect that most OnlineStats users are just using Mean, Variance, the histogram types, etc. (the simple stuff). I don't want to force unnecessary deps on them.

  3. There's really not anything I want to use from Distributions apart from creating the distribution types.

  4. I'm risk averse. If there's a chance a change will increase the time required to maintain OnlineStats, I kinda need to avoid it. My bandwidth is very limited at the moment.

Red-Portal commented 2 years ago

That makes sense. Thanks for taking the time for an explanation. Let me know if you have anything to say about the PR.

Red-Portal commented 2 years ago

Added test, tidied code, and removed the Distributions.jl dependency. The PR description above has been updated accordingly.

joshday commented 2 years ago

Just now circling back to this PR. I think I've talked myself into allowing a dependency on Distributions.

Sorry for the added work, but I can either merge as-is or I'll merge once you've added it back in.

Red-Portal commented 2 years ago

Hi Josh, funny because you convinced me to not mess with Distributions.jl haha. Are you sure about adding the dependency? If so, I'll add the stuff back in.

joshday commented 2 years ago

I know, I know. Sorry!

Yes I'm sure. Folks who want the lightest-weight option can use OnlineStatsBase.
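For reference, the lightweight route looks roughly like this, assuming a recent OnlineStatsBase 1.x, which as far as I know ships `Mean` and `Variance` itself:

```julia
using OnlineStatsBase

# Basic stats from OnlineStatsBase alone; no Distributions.jl in the dependency tree
m = fit!(Mean(), [1.0, 2.0, 3.0])
v = fit!(Variance(), [1.0, 2.0, 3.0])

println(value(m))   # running mean
println(value(v))   # running (unbiased) variance
println(nobs(m))    # number of observations seen
```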

SpecialFunctions and StatsFuns are also heavy deps in the sense that they add binaries, but they're so essential to so many Julia packages that I'm less concerned about version-bound troubles.

It would make a few things easier to code/maintain if we could use the Distribution types directly.

Red-Portal commented 2 years ago

Added Distributions.jl back, added an API for resetting the hyperparameters, and improved documentation. Pretty much done on my side!

joshday commented 2 years ago

Running CI again. I'll merge when I see green unless you have anything else to add!

Red-Portal commented 2 years ago

I'm not really used to Julia's documentation system, so please let me know if there's anything to improve on that side.