JuliaStats / StatsBase.jl

Basic statistics for Julia
Other
585 stars 190 forks source link

Pooled variance #478

Open carstenbauer opened 5 years ago

carstenbauer commented 5 years ago

I couldn't find a function that calculates the combined/pooled variance of two (or more) datasets. I think it would be great to offer this.

Given two samples `x1` and `x2`, the combined variance is the variance of the concatenated sample `vcat(x1, x2)`. I came up with the following implementation:

"""

    combined_mean_and_var(x1, x2) -> meanc, varc

Given two samples `x1`,`x2` calculates the mean and variance of the
concatenated sample.
"""
function combined_mean_and_var(x1::AbstractVector{<:Number}, x2::AbstractVector{<:Number})
    n1, n2 = length(x1), length(x2)
    μ1, μ2 = mean(x1), mean(x2)
    v1, v2 = var(x1), var(x2)
    return combined_mean_and_var(n1, μ1, v1, n2, μ2, v2)
end

"""

    combined_mean_and_var(n1, μ1, v1, n2, μ2, v2) -> meanc, varc

Given two samples characterized by their lengths `n1`, `n2`, 
their means `μ1`, `μ2`, and their variances `v1`, `v2`,
calculates the combined (or pooled) mean and variance of the
concatenated sample.
"""
function combined_mean_and_var(n1::Integer, μ1::Number, v1::Number,
                               n2::Integer, μ2::Number, v2::Number)
    meanc = (n1 * μ1 + n2 * μ2) / (n1 + n2)

    # Based on https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html,
    # including Robert Matheson's comment and adding abs for complex number support.
    varc = ((n1-1)*v1 + (n2-1)*v2 + n1*abs2(μ1 - meanc) + n2*abs2(μ2 - meanc)) /
                                (n1 + n2 - 1)
    return meanc, varc
end

"""

    combined_mean_and_var(xs...) -> meanc, varc

Calculates the combined mean and variance of the concatenated sample `vcat(xs...)`.
"""
function combined_mean_and_var(xstuple::AbstractVector{<:Number}...)
    # Could perhaps be accelerated by using @generated functions and producing
    # explicit (hand written) formulas?
    xs = collect(xstuple)
    ns = map(length, xs)
    μs = map(mean, xs)
    vs = map(var, xs)
    return combined_mean_and_var(ns, μs, vs)
end

"""

    combined_mean_and_var(ns, μs, vs) -> meanc, varc

Given N samples characterized by their lengths `ns`,
their means `μs`, and their variances `vs`,
calculates the combined (or pooled) mean and variance of the
overall (concatenated) sample.
"""
function combined_mean_and_var(ns::AbstractVector{<:Integer},
                               μs::AbstractVector{<:Number},
                               vs::AbstractVector{<:Number})
    nsum = sum(ns)
    meanc = dot(ns, μs) / nsum
    varc = sum((ns .- 1) .* vs + ns .* abs2.(μs .- meanc)) / (nsum - 1)
    return meanc, varc
end

and some tests

using Test

# Reference implementation for three samples, used by the tests. It merges the
# first two samples with the two-sample summary formula, then merges that result
# with the third sample, so it can be compared against the N-sample method.
function combined_mean_and_var_three(x1, x2, x3)
    summarize(x) = (length(x), mean(x), var(x))
    (na, μa, va), (nb, μb, vb), (nc, μc, vc) = summarize(x1), summarize(x2), summarize(x3)

    μab, vab = combined_mean_and_var(na, μa, va, nb, μb, vb)
    return combined_mean_and_var(na + nb, μab, vab, nc, μc, vc)
end

"""
    test_combined_mean_and_var()

Run the "Combined Mean and Variance" test set. Every `combined_mean_and_var`
method is compared against `mean`/`var` of the explicitly concatenated sample,
for both `Float64` and `ComplexF64` random data. Returns `nothing`.
"""
function test_combined_mean_and_var()
    tol = 1e-12

    # Assert that (m, v) matches the mean and variance of the reference sample `ref`.
    function check_against(ref, m, v)
        @test abs(mean(ref) - m) < tol
        @test abs(var(ref) - v) < tol
    end

    function check_two(a, b)
        m, v = combined_mean_and_var(a, b)
        check_against(vcat(a, b), m, v)
    end

    # Also cross-checks the varargs method against the two-step reference version.
    function check_three(a, b, c)
        m, v = combined_mean_and_var(a, b, c)
        check_against(vcat(a, b, c), m, v)

        m3, v3 = combined_mean_and_var_three(a, b, c)
        @test isapprox(m, m3)
        @test isapprox(v, v3)
    end

    function check_many(samples...)
        m, v = combined_mean_and_var(samples...)
        check_against(vcat(samples...), m, v)
    end

    function check_moments(ns, μs, vs, exact_mean, exact_var)
        m, v = combined_mean_and_var(ns, μs, vs)
        @test abs(exact_mean - m) < tol
        @test abs(exact_var - v) < tol
    end

    sizes = [30_000, 20_000, 40_000]
    nsamples = 5

    @testset "Combined Mean and Variance" begin
        @testset "Two samples" begin
            for T in (Float64, ComplexF64)
                check_two(rand(T, 30_000), rand(T, 20_000))
            end
        end

        @testset "Three samples" begin
            for T in (Float64, ComplexF64)
                check_three(rand(T, 30_000), rand(T, 20_000), rand(T, 40_000))
            end
        end

        @testset "N samples" begin
            for T in (Float64, ComplexF64)
                check_many([rand(T, rand(sizes)) for _ in 1:nsamples]...)
            end
        end

        @testset "N samples (ns, μs, vs)" begin
            for T in (Float64, ComplexF64)
                samples = [rand(T, rand(sizes)) for _ in 1:nsamples]
                pooled = vcat(samples...)
                check_moments(length.(samples), mean.(samples), var.(samples),
                              mean(pooled), var(pooled))
            end
        end
    end
    return nothing
end

Do you guys think it'd be worth adding something like this?

nalimilan commented 5 years ago

var(Iterators.flatten([x1, x2])) already works, so I'm not sure we need to add anything. One of Julia's strengths is that features can easily be combined that way.

carstenbauer commented 5 years ago

Your approach is much slower, though:

julia> x1 = rand(1_000_000);

julia> x2 = rand(1_000_000);

julia> @btime var(Iterators.flatten([$x1, $x2]));
  12.272 ms (2 allocations: 112 bytes)

julia> @btime combined_mean_and_var($x1, $x2);
  2.066 ms (0 allocations: 0 bytes)

More importantly, with `combined_mean_and_var(ns, μs, vs)` one can calculate the combined variance from the lengths, means, and variances alone, without needing the full time series. I don't see a simple replacement for that.

nalimilan commented 5 years ago

If we really care about performance, we could provide a custom method for `var(::Iterators.Flatten)`. There's also CatViews.jl. (Maybe `flatten` could get faster too.)

The method taking only the summary statistics is a completely different beast. I'm not sure about that since we don't include any function like that currently AFAICT. Are there other examples of this kind of thing in stats? Do other programs support this (and how)?

ararslan commented 5 years ago

FYI, I added an implementation of covariance matrix pooling in HypothesisTests, as it's used for a couple of multivariate tests. See https://github.com/JuliaStats/HypothesisTests.jl/blob/master/src/common.jl#L61-L67.