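In OnlineStatsBase.jl, why are some statistics types subtyped with `OnlineStat{Number}`? For example, `Mean` is subtyped with `OnlineStat{Number}`. Is there a reason we can't have `mutable struct Mean{T,W} <: OnlineStat{T}` instead? As it stands, calling `input()` on statistics like `Mean()` always returns `Number` rather than the actual input type (e.g. `Float32`). The issue appears to affect `Mean`, `Moments`, `Sum`, and `Variance`.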
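To make the question concrete, here is a self-contained toy model. The names `ToyStat`, `NumberMean`, `TypedMean`, and this `input` function are hypothetical stand-ins, not OnlineStatsBase's actual source. In the toy, `input` simply reports the type parameter of the abstract supertype, so declaring a stat against `ToyStat{Number}` makes it report `Number` even when the stored value is a `Float32`, while forwarding the parameter (`ToyStat{T}`) makes it report `Float32`.

```julia
# Toy model, NOT OnlineStatsBase's source: the abstract type's parameter
# records the input type a stat accepts.
abstract type ToyStat{T} end

# Hypothetical stand-in for input(): return the supertype's parameter.
input(::ToyStat{T}) where {T} = T

# Variant 1: supertype parameter fixed to Number.
mutable struct NumberMean{T,W} <: ToyStat{Number}
    μ::T
    weight::W
    n::Int
end

# Variant 2: supertype parameter forwarded from the field type T.
mutable struct TypedMean{T,W} <: ToyStat{T}
    μ::T
    weight::W
    n::Int
end

a = NumberMean(0.0f0, nothing, 0)
b = TypedMean(0.0f0, nothing, 0)
println(input(a))  # Number
println(input(b))  # Float32
```

In this toy the field `μ` is a `Float32` in both variants; only what `input` reports differs.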
I noticed this while playing around with a Mean/Stdev filter. My original code is as follows (feel free to suggest better or more efficient ways to do this; I'm new to this package):
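```julia
using BenchmarkTools
using OnlineStatsBase

mutable struct MeanStdFilter{T}
    nu::Int               # number of input dimensions
    tracker::OnlineStat   # Group of per-dimension Series(Mean, Variance)
end

function MeanStdFilter(nu::Int; T::DataType=Float32)
    # One Series(Mean, Variance) per dimension; Group fits x[i] to the i-th Series
    s = [Series(Mean(T), Variance(T)) for _ in 1:nu]
    return MeanStdFilter{T}(nu, Group(s...))
end

function _get_mean_var(m::MeanStdFilter{T}) where T
    # (mean, variance) for each dimension, viewed as a 2×nu matrix of T
    vals = value.(value(m.tracker))
    return reinterpret(reshape, T, collect(vals))
end

function (m::MeanStdFilter)(x::AbstractVector)
    fit!(m.tracker, x)                          # update running statistics with x
    μσ2 = _get_mean_var(m)                      # row 1: means, row 2: variances
    return (x .- μσ2[1,:]) ./ sqrt.(μσ2[2,:])   # standardize x
end

# Test runtime
nu = 4
T = Float32
m = MeanStdFilter(nu; T)
# @btime m(randn(T,nu));
@btime _get_mean_var(m);
```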
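For comparison, here is a minimal pure-Julia sketch of what the filter above computes, with no OnlineStatsBase dependency. It keeps a per-dimension running mean and variance using Welford's algorithm (a swapped-in technique; OnlineStatsBase's `Mean` and `Variance` may update differently), and all state lives in concrete `Vector{T}` fields so a `Float32` filter stays `Float32` end to end. The names `WelfordFilter`, `update!`, and `variances` are hypothetical.

```julia
# Hypothetical pure-Julia running mean/std standardizer (Welford's algorithm).
# Not OnlineStatsBase's implementation.
mutable struct WelfordFilter{T<:AbstractFloat}
    n::Int
    μ::Vector{T}    # running means, one per dimension
    M2::Vector{T}   # running sums of squared deviations from the mean
end

WelfordFilter{T}(nu::Int) where {T<:AbstractFloat} =
    WelfordFilter{T}(0, zeros(T, nu), zeros(T, nu))

# Update the running statistics with one observation vector x.
function update!(f::WelfordFilter{T}, x::AbstractVector{T}) where {T}
    f.n += 1
    for i in eachindex(x, f.μ)
        δ = x[i] - f.μ[i]
        f.μ[i] += δ / f.n                # T / Int stays T
        f.M2[i] += δ * (x[i] - f.μ[i])
    end
    return f
end

# Sample variance (n - 1 denominator). Returns one(T) before two samples
# have been seen, an arbitrary choice to avoid 0/0.
variances(f::WelfordFilter{T}) where {T} =
    f.n < 2 ? fill(one(T), length(f.M2)) : f.M2 ./ (f.n - 1)

# Fit x, then standardize it with the updated statistics (same order as above).
function (f::WelfordFilter{T})(x::AbstractVector{T}) where {T}
    update!(f, x)
    return (x .- f.μ) ./ sqrt.(variances(f))
end

f = WelfordFilter{Float32}(4)
y = f(randn(Float32, 4))
println(eltype(y))  # Float32
```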
Running with `T = Float32` I get:

```
1.014 μs (18 allocations: 608 bytes)
```

and with `T = Float64` it speeds up to:

```
549.342 ns (6 allocations: 480 bytes)
```

I suspect this is due to a conversion from `Float64` to `Float32` somewhere in the pipeline, caused by the issue raised above.
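A generic illustration of the kind of conversion suspected here, in plain Julia and not tied to OnlineStatsBase's internals (whether this is what actually happens inside the package is an assumption, not something measured): once a `Float64` enters a `Float32` computation, Julia's promotion rules widen the result to `Float64`, and getting back to `Float32` takes an explicit conversion, for example when storing into a `Float32` array. An abstract parameter such as `Number` also gives the compiler no concrete type to work with.

```julia
# Arithmetic mixing Float32 with a Float64 promotes to Float64.
a = 1.5f0                  # Float32
b = a * 2.0                # 2.0 is a Float64 literal
println(typeof(b))         # Float64
println(promote_type(Float32, Float64))  # Float64

# Storing the result in a Float32 array converts it back (an extra step).
buf = zeros(Float32, 1)
buf[1] = b
println(typeof(buf[1]))    # Float32

# Number is abstract, Float32 is concrete.
println(isconcretetype(Number), " ", isconcretetype(Float32))  # false true
```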
Thanks in advance for any help!