gdkrmr / WeightedOnlineStats.jl

Weighted version of OnlineStats.jl
MIT License
10 stars 4 forks source link

fit!(WeightedMean(), X, W) gives the unweighted mean #51

Closed melwey closed 1 year ago

melwey commented 1 year ago

Is this the intended behaviour?

X = ones(5);
W = [.5,.5,1,1,0];
o = fit!(WeightedMean(), X, W)
# I would expect
# WeightedMean: ∑wᵢ=3.0 | value=0.6
# but I get
# WeightedMean: ∑wᵢ=3.0 | value=1.0
# to get the weighted mean, I need to do
o.µ * o.W

If not, this could fix it: In WeightedMean.jl


function OnlineStatsBase._fit!(o::WeightedMean{T}, x, w) where T
    xx = convert(T, x)
    ww = convert(T, w)

    o.n += 1
    o.W = smooth(o.W, ww, T(1) / o.n)
    # instead of
    # o.µ = smooth(o.μ, xx, ww / (o.W * o.n))
    # do 
    o.μ = smooth(o.μ, xx * ww, T(1) / o.n) 

    o
end

function OnlineStatsBase._merge!(o::MyWeightedMean{T}, o2::MyWeightedMean) where T
    o2_W = convert(T, o2.W)
    o2_μ = convert(T, o2.μ)

    o.n += o2.n
    o.W = smooth(o.W, o2_W, o2.n / o.n)
    # instead of
    # o.μ = smooth(o.μ, o2_μ, (o2_W * o2.n) / (o.W * o.n))
    # do
    o.μ = smooth(o.μ, o2_μ, o2.n / o.n)

    o
end
melwey commented 1 year ago

Never mind. I got it. Sorry

gdkrmr commented 1 year ago

in case someone stumbles over this: you have to use mean(o) to get the numerical value of the mean.