joshday / OnlineStats.jl

⚡ Single-pass algorithms for statistics
https://joshday.github.io/OnlineStats.jl/latest/
MIT License

Possible type instability in `OnlineStatsBase.jl` #265

Closed: nic-barbara closed this issue 1 year ago.

nic-barbara commented 1 year ago

In OnlineStatsBase.jl, why are some statistic types subtyped as `OnlineStat{Number}`? For example:

```julia
mutable struct Mean{T,W} <: OnlineStat{Number}
    μ::T
    weight::W
    n::Int
end
Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight, 0)
```

Is there a reason we can't have `mutable struct Mean{T,W} <: OnlineStat{T}` instead? With the current definition, calling `input()` on a statistic such as `Mean()` always returns `Number` rather than the actual input type (e.g. `Float32`). The issue appears to affect `Mean`, `Moments`, `Sum`, and `Variance`.
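A minimal, self-contained sketch of the difference, using hypothetical toy types (`ToyStat`, `MeanAsNumber`, `MeanAsT`, `toy_input`) that only mirror the shape of `OnlineStat{T}` and `input`; they are assumptions for illustration, not the OnlineStatsBase definitions:

```julia
# Hypothetical stand-in for `OnlineStat{T}`: T is the declared input type.
abstract type ToyStat{T} end

# Mirrors the idea behind `input`: report the supertype's type parameter.
toy_input(::ToyStat{T}) where {T} = T

# Current pattern: the supertype parameter is fixed to `Number`.
mutable struct MeanAsNumber{T} <: ToyStat{Number}
    μ::T
    n::Int
end

# Proposed pattern: the supertype parameter follows the element type.
mutable struct MeanAsT{T} <: ToyStat{T}
    μ::T
    n::Int
end

toy_input(MeanAsNumber(0.0f0, 0))  # Number
toy_input(MeanAsT(0.0f0, 0))       # Float32
```

One thing worth checking before making such a change (not verified here): if the package dispatches single-observation updates on the declared input type, narrowing it from `Number` to `Float32` could change how inputs of other number types are handled.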


I noticed this while experimenting with a mean/standard-deviation filter. My original code is below (suggestions for better or more efficient ways to do this are welcome; I'm new to this package).

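```julia
using BenchmarkTools
using OnlineStatsBase

mutable struct MeanStdFilter{T}
    nu::Int               # number of input channels
    tracker::OnlineStat   # abstractly typed field: its concrete type is unknown to the compiler
end

function MeanStdFilter(nu::Int; T::DataType=Float32)
    # One Series tracking (mean, variance) per channel, combined in a Group
    s = [Series(Mean(T), Variance(T)) for _ in 1:nu]
    return MeanStdFilter{T}(nu, Group(s...))
end

# Collect the per-channel (mean, variance) tuples into a 2×nu matrix of T
function _get_mean_var(m::MeanStdFilter{T}) where T
    vals = value.(value(m.tracker))
    return reinterpret(reshape, T, collect(vals))
end

# Update the statistics with x, then standardize x channel by channel
function (m::MeanStdFilter)(x::AbstractVector)
    fit!(m.tracker, x)
    μσ2 = _get_mean_var(m)
    return (x .- μσ2[1,:]) ./ sqrt.(μσ2[2,:])
end

# Test runtime
nu = 4
T = Float32
m = MeanStdFilter(nu; T)

# @btime m(randn(T,nu));
# Note: BenchmarkTools recommends interpolating globals (`$m`); the timings below were taken without it.
@btime _get_mean_var(m);
```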
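Separately from the `OnlineStat{Number}` question, the `tracker::OnlineStat` field is abstractly typed, which is a common source of type instability in Julia. A minimal sketch of the usual remedy, storing the concrete tracker type as a type parameter, assuming OnlineStatsBase is installed as in the original code (`TypedMeanStdFilter` is a hypothetical name):

```julia
using OnlineStatsBase

# Hypothetical variant of MeanStdFilter: the concrete type of the tracker is
# recorded in the type parameter G, so accessing `tracker` is type-stable.
# An immutable struct suffices because `fit!` mutates the stats in place.
struct TypedMeanStdFilter{T,G}
    nu::Int
    tracker::G
end

function TypedMeanStdFilter(nu::Int; T::Type{<:AbstractFloat}=Float32)
    # Same construction as the original: one Series(Mean, Variance) per channel
    g = Group([Series(Mean(T), Variance(T)) for _ in 1:nu]...)
    return TypedMeanStdFilter{T,typeof(g)}(nu, g)
end

f = TypedMeanStdFilter(4)
isconcretetype(fieldtype(typeof(f), :tracker))  # true
```

The original call method and `_get_mean_var` should carry over with only the type name in their signatures changed. Whether this closes the Float32/Float64 gap reported below has not been measured here.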

Running with `T = Float32`, I get:

```
1.014 μs (18 allocations: 608 bytes)
```

whereas with `T = Float64` the time drops to:

```
549.342 ns (6 allocations: 480 bytes)
```

I suspect this is due to a Float64-to-Float32 conversion somewhere in the pipeline, caused by the issue raised above.
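To make the suspected conversion concrete, here is a Base-only sketch of an equal-weight mean update where the element type is `Float32` but the weight is a `Float64`. The `ToyMean`, `toy_fit!`, and `toy_fit_same_type!` names are hypothetical, and the assumption that the package's weights are `Float64` values such as `1/n` is not verified here:

```julia
# Hypothetical mean tracker with element type T (e.g. Float32).
mutable struct ToyMean{T<:AbstractFloat}
    μ::T
    n::Int
end

# Equal-weight update μ ← μ + γ (x - μ) with γ = 1/n.
# `1 / n` on Ints is a Float64, so for T == Float32 the right-hand side is
# promoted to Float64 and then converted back to Float32 on assignment.
function toy_fit!(m::ToyMean{T}, x::T) where {T}
    m.n += 1
    γ = 1 / m.n
    m.μ = m.μ + γ * (x - m.μ)
    return m
end

# Computing the weight in T keeps the whole update in T.
function toy_fit_same_type!(m::ToyMean{T}, x::T) where {T}
    m.n += 1
    γ = one(T) / m.n
    m.μ = m.μ + γ * (x - m.μ)
    return m
end
```

A scalar `Float64`-to-`Float32` conversion like this is cheap and does not allocate, so on its own it is unlikely to explain the extra allocations in the `Float32` run; it only shows where such a conversion could arise.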

Thanks in advance for any help!

nic-barbara commented 1 year ago

Actually, since this is an issue with OnlineStatsBase.jl, I'll move the discussion over there. Apologies for the inconvenience.