JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Discussion: Extended MLE for exponential family distributions #110

Closed · lindahua closed this 11 years ago

lindahua commented 11 years ago

It is well known that MLE for exponential family distributions (a category that includes many useful distributions) consists of two steps:

  1. compute sufficient statistics
  2. transform sufficient statistics to distribution parameters

In the practice of machine learning, we often have to estimate a model on weighted data (e.g. when estimating a mixture model, or anything with a data-association problem), or even to estimate the model directly from sufficient statistics computed in other ways (e.g. variational inference).

To support all such uses, we may consider refactoring the MLE code and creating a more versatile API as follows:

# Let D be a distribution type

s = suffstats(D, data)       # sufficient statistics on data
s = suffstats(D, data, w)    # sufficient statistics on weighted data

fit_mle_on_stats(D, s)       # construct a model from pre-computed sufficient statistics

fit_mle(D, data) = fit_mle_on_stats(D, suffstats(D, data))
fit_mle(D, data, w) = fit_mle_on_stats(D, suffstats(D, data, w))

# The two functions above can be generic fallbacks, so that
# for each distribution we only have to implement
# suffstats and fit_mle_on_stats.

I can do this (it is not difficult and would not take much time), but I would like to hear what you think about it first.

simonbyrne commented 11 years ago

It certainly seems like a good idea. A while ago I proposed creating a WeightedVector type, either in Base or Stats.jl (see here); that could still be worth doing, given the number of statistical functions that would use it.

You could also create a SufficientStatistic type, but that could be tricky to implement (would it be a parametric type?).

lindahua commented 11 years ago

The vision here is not just to deal with weighted samples, but also to cover cases where the sufficient statistics are computed by other means (e.g. in the variational inference algorithm for LDA, the sufficient statistics are computed in a special way rather than directly from (weighted) samples).

What I am thinking is to let each distribution determine whatever representation it considers suitable for its sufficient statistics. For example, suffstats(Normal, x) can return a NormalStats and suffstats(Dirichlet, x) a DirichletStats; it is entirely up to whoever implements the distribution to decide the type of the sufficient statistics and to implement a fit_mle_on_stats method that accepts it.
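
A minimal sketch of how this could look (written in current Julia syntax; the abstract SufficientStats type, the field layout, and the method bodies are illustrative assumptions, not a committed design):

using Distributions: Normal

# Illustrative only: each distribution defines its own stats type.
abstract type SufficientStats end

struct NormalStats <: SufficientStats
    s::Float64     # sum of x
    s2::Float64    # sum of x.^2
    tw::Float64    # total weight (sample size for unweighted data)
end

# Sufficient statistics from raw and weighted samples.
suffstats(::Type{Normal}, x::AbstractVector) =
    NormalStats(sum(x), sum(abs2, x), length(x))

suffstats(::Type{Normal}, x::AbstractVector, w::AbstractVector) =
    NormalStats(sum(w .* x), sum(w .* abs2.(x)), sum(w))

# Turn pre-computed statistics into parameter estimates.
function fit_mle_on_stats(::Type{Normal}, ss::NormalStats)
    μ = ss.s / ss.tw
    σ² = ss.s2 / ss.tw - abs2(μ)   # not the numerically stablest form (see below)
    Normal(μ, sqrt(σ²))
end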

simonbyrne commented 11 years ago

That makes sense; it should also make it easier to create a method for conjugate updating of posteriors.

You could just dispatch fit_mle on the NormalStats type?

johnmyleswhite commented 11 years ago

I like this idea in principle: having an abstract type SufficientStatistics and concrete types like NormalStats seems like a good solution to me. My one concern is that the sufficient statistics may not be a good tool for computing MLEs in a numerically stable way. See, for example, this description of the problems with fitting a normal using the sum of squares: http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
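
(For concreteness, a small sketch of the failure mode that post describes; nothing here is specific to the proposed API:)

x = 1e8 .+ [0.0, 1.0, 2.0]   # data with a large mean and a tiny spread
n = length(x)

# one-pass formula built from the raw sufficient statistics: catastrophic cancellation
var_naive = (sum(abs2, x) - sum(x)^2 / n) / n

# two-pass formula: subtract the mean first, then square
m = sum(x) / n
var_stable = sum(abs2, x .- m) / n

# var_stable ≈ 2/3, while var_naive is badly wrong (it can come out as exactly 0.0)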

lindahua commented 11 years ago

@johnmyleswhite I understand the numerical issues of using E(x^2) - (E(x))^2 to compute variance.

The fit_mle method that uses sufficient statistics is a fallback method. For a particular distribution, say Normal, one can still override its behavior by implementing a specialized version.

That being said, except for some special cases, this fallback works pretty well.

lindahua commented 11 years ago

This approach has an additional advantage: it paves the way towards implementing generic MAP estimation (with a conjugate prior), which basically amounts to adding the prior parameters to the sum of sufficient statistics.
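
For instance, a hypothetical sketch of that idea for the Beta-Bernoulli pair (map_bernoulli is a made-up name for this illustration):

# MAP for a Bernoulli parameter with a Beta(α, β) prior (assumes α, β > 1):
# the conjugate update simply adds the prior pseudo-counts to the observed
# sufficient statistics (number of successes and number of trials).
function map_bernoulli(x::AbstractVector{Bool}, α::Real, β::Real)
    s = count(x)                      # sufficient statistic: number of successes
    n = length(x)
    (α - 1 + s) / (α + β - 2 + n)     # mode of the Beta(α + s, β + n - s) posterior
end

map_bernoulli([true, true, false, true], 2.0, 2.0)   # 4/6 ≈ 0.667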

johnmyleswhite commented 11 years ago

I definitely see the virtues of this approach. A while back, I started thinking about how to make a consistent interface for conjugate prior updating: https://github.com/johnmyleswhite/ConjugatePriors.jl

I say let's go ahead with this idea, but make sure we understand how the fallback behaves for each distribution.

Even if we didn't use the sufficient statistics for MLEs, I think having a consistent interface to their computation is a great idea and fits perfectly with our goal of exposing theoretical objects as functions on distribution types.

lindahua commented 11 years ago

@johnmyleswhite Your conjugate prior work looks nice. May I suggest bringing the development of the conjugate prior code into Distributions.jl?

Clearly, conjugate priors depend on distributions, while parts of this package (e.g. MAP estimation) would benefit from the conjugate prior work; the two are closely coupled.

I am not completely sure about the best approach to dealing with conjugate pairs, but we can discuss this.

lindahua commented 11 years ago

Implemented by 517cf8f.

johnmyleswhite commented 11 years ago

I'll make a WIP PR for the conjugate priors so that we can discuss their interface and implementation.

lindahua commented 11 years ago

That would be great!

simonbyrne commented 11 years ago

Just out of curiosity, is there any reason why you couldn't use the mean and variance as the sufficient statistics (instead of the natural statistics) for the normal distribution?

lindahua commented 11 years ago

What constitutes a sufficient statistic is actually very well defined for exponential family distributions (e.g. Normal). Generally, an exponential family distribution can be expressed as

p(x; θ) = h(x) exp( η(θ)' T(x) - A(θ) )

Here, T(x) is called the sufficient statistic. In particular, for the Normal distribution the sufficient statistic has two components, (x, x^2) (this can be seen by expanding the pdf).
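
Spelling out that expansion for reference: the Normal log-density can be written as

log p(x; μ, σ) = -(x - μ)^2 / (2σ^2) - log(σ √(2π))
               = (μ/σ^2) x - (1/(2σ^2)) x^2 - ( μ^2/(2σ^2) + log(σ √(2π)) )

so η(θ) = (μ/σ^2, -1/(2σ^2)) and T(x) = (x, x^2), with h(x) = 1 and the x-free terms absorbed into A(θ).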

In this design, we compute the sum (or weighted sum) of T(x) and store it in a sufficient-statistics instance, which can be used in both MLE and MAP estimation. The reason I chose to store the sum (instead of the mean) is that the sum can be used directly in MAP estimation, and one can easily derive the mean from the sum for MLE.

This works for many distributions. For the special case of the Normal distribution, however, it is not as numerically stable as directly computing the mean and covariance for MLE, so I override the fit_mle method for Normal. That said, the sums of x and x^2 are still very useful in MAP estimation.

johnmyleswhite commented 11 years ago

I may be missing something, but I think Simon's point is that a distribution's sufficient statistics are not unique: the exponential family natural statistics we are using became popular well after Fisher's original definition of sufficiency, which only requires that the likelihood function satisfy a simple invariance property after conditioning on a set of sufficient statistics.

simonbyrne commented 11 years ago

Sorry, my question may have been a bit terse. @johnmyleswhite is correct: my intended question was, given that (mean, variance) and the natural statistics (sum(x), sum(x.^2)) are both sufficient, is there any reason to prefer the natural statistics?

@lindahua, you mentioned that the sum is slightly easier to work with for MAP estimation (as you don't need to keep multiplying and dividing by n): in that case, could we use (sum(x), sum((x .- mean(x)).^2))? These are also sufficient, and can be computed in a numerically stable way using a slightly modified version of Welford's algorithm:

function suffstats(x)
    # track the running sum s and the running mean (m_prev)
    m_prev = s = x[1]
    ss = 0.0
    for i = 2:length(x)
        s += x[i]
        m_next = s / i
        # Welford-style update of the running sum of squared deviations
        ss += (x[i] - m_prev) * (x[i] - m_next)
        m_prev = m_next
    end
    (s, ss)
end
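
(For what it is worth, a quick check that this one-pass version agrees with the naive computations:)

x = randn(1000)
s, ss = suffstats(x)
s ≈ sum(x)                              # true, up to rounding
ss ≈ sum(abs2, x .- (s / length(x)))    # true, up to rounding
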
lindahua commented 11 years ago

@simonbyrne This seems like a good suggestion. I am going to try the approach of keeping sum(x) and sum((x - μ)^2), and your code that computes both in one pass is useful.

lindahua commented 11 years ago

I compared the performance of this against the traditional way, which first computes μ and then the sum of squared differences:

function ssq1(x::Array)
    n = length(x)

    # compute μ
    s = x[1]
    for i = 2:n
        @inbounds s += x[i]
    end
    μ = s / n

    # compute sum((x - μ)^2)
    ss = abs2(x[1] - μ)
    for i = 2:n
        @inbounds ss += abs2(x[i] - μ)
    end
    ss
end

function ssq2(x::Array)
    n = length(x)
    m_prev = s = x[1]
    ss = 0.0
    for i = 2:n
        @inbounds xi = x[i]
        s += xi
        m_next = s / i
        ss += (xi - m_prev) * (xi - m_next)
        m_prev = m_next
    end
    ss
end

x = randn(10)
r1 = ssq1(x)
r2 = ssq2(x)

x = randn(10^7)
@time ssq1(x)
@time ssq2(x)

Results:

elapsed time: 0.016012012 seconds (7092 bytes allocated)   # ssq1
elapsed time: 0.089747646 seconds (64 bytes allocated)     # ssq2

The traditional way is still considerably faster (about 5.6x here). I will go with your suggestion, but using the traditional two-pass implementation.

johnmyleswhite commented 11 years ago

I think the traditional implementation is the right way to go, but we may want to expose a mechanism for Welford's algorithm in the way that Base exposes sum and sum_kbn.
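
A minimal sketch of that naming pattern (these names are hypothetical and do not exist in Base or Distributions; they simply mirror the sum / sum_kbn pattern and delegate to the two implementations benchmarked above):

sqdiffsum(x::Array) = ssq1(x)           # default: fast two-pass sum of squared deviations
sqdiffsum_welford(x::Array) = ssq2(x)   # one-pass alternative, e.g. for streaming data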

simonbyrne commented 11 years ago

I'm becoming less surprised when the simple but seemingly inefficient algorithm turns out to be faster: I've tried a couple of times to tweak the alias table implementation to get rid of the queue arrays by iterating, but it always performs worse.

lindahua commented 11 years ago

Implemented in 62315b7.

dmbates commented 11 years ago

By the way, a statistician would not write the sample mean as μ: μ is a parameter of the distribution, whereas the sample mean is the observed value of a statistic. It may seem like a picky distinction, but it is one way of not getting totally confused about the properties of estimators.

lindahua commented 11 years ago

@dmbates Thanks for the comment. I understand the distinction between the expectation and the sample mean; I had just become fond of using Greek letters in code. Commit f8eccf6 changes μ to m to denote the sample mean.