Open alexandrebouchard opened 1 year ago
Maybe first check whether this is the line doing all the allocs: https://github.com/Julia-Tempering/Pigeons.jl/blob/a3afb4f339ff2f2917b7c2be82ce85884792a489/src/recorders/OnlineStateRecorder.jl#L73
Never mind my last comment, we can't work around it anyway. I think we really need a vectorized version.
@alexandrebouchard it turns out that StructArrays makes this trivial. For example,
using OnlineStats, StructArrays # OnlineStats provides Series, Mean, Variance, fit!
series_vec = StructVector(Series(Mean(), Variance()) for i in 1:3)
data = [i .+ randn(1000) for i in 1:3]
fit!.(series_vec, data) # note the broadcasting dot
which works as expected
julia> series_vec
3-element StructArray(::Vector{Tuple{Mean{Float64, EqualWeight}, Variance{Float64, Float64, EqualWeight}}}) with eltype Series{Number, Tuple{Mean{Float64, EqualWeight}, Variance{Float64, Float64, EqualWeight}}}:
Series
├─ Mean: n=1_000 | value=0.916949
└─ Variance: n=1_000 | value=0.985739
Series
├─ Mean: n=1_000 | value=1.96672
└─ Variance: n=1_000 | value=1.00922
Series
├─ Mean: n=1_000 | value=3.07938
└─ Variance: n=1_000 | value=0.962972
nice find!!
At the moment, vectors are handled with many small OnlineStats objects; see https://joshday.github.io/OnlineStats.jl/latest/api/#OnlineStatsBase.Group
Consider the situation where the state space is a single vector. At dimensions of roughly 10,000 or more, creating the recorder based on OnlineStatsBase.Group starts to become the bottleneck.
We would like something like Mean(Vector{Float64}) instead of what is used at https://github.com/Julia-Tempering/Pigeons.jl/blob/a3afb4f339ff2f2917b7c2be82ce85884792a489/src/recorders/OnlineStateRecorder.jl#L75. Unfortunately, OnlineStats does not seem to have one, so we might have to roll our own.
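For illustration only, here is a minimal sketch of what "rolling our own" could look like: a single accumulator that tracks the mean and variance of a whole vector with Welford's update, so that only two arrays are allocated regardless of the dimension. The names `VectorMeanVar`, `update!`, and `sample_variance` are hypothetical. They are not part of OnlineStats or Pigeons, and the sketch does not implement the OnlineStat interface or any weighting scheme.

```julia
# Hypothetical type, not part of OnlineStats or Pigeons.
# Holds one running mean and one running sum of squared deviations
# per coordinate, instead of one small object per coordinate.
mutable struct VectorMeanVar
    n::Int                 # number of observations seen so far
    mean::Vector{Float64}  # running mean, one entry per coordinate
    m2::Vector{Float64}    # running sum of squared deviations (Welford)
end

# Allocate the two buffers once, at construction time.
VectorMeanVar(dim::Integer) = VectorMeanVar(0, zeros(dim), zeros(dim))

# Update all coordinates in place with one new observation `x`.
function update!(s::VectorMeanVar, x::AbstractVector{<:Real})
    length(x) == length(s.mean) ||
        throw(DimensionMismatch("expected a vector of length $(length(s.mean))"))
    s.n += 1
    for i in eachindex(s.mean, x)
        delta = x[i] - s.mean[i]
        s.mean[i] += delta / s.n
        s.m2[i] += delta * (x[i] - s.mean[i])
    end
    return s
end

# Unbiased sample variance per coordinate (needs at least two observations).
sample_variance(s::VectorMeanVar) = s.m2 ./ (s.n - 1)

# Usage: two observations of a 3-dimensional state.
s = VectorMeanVar(3)
for x in ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0])
    update!(s, x)
end
s.mean              # [2.0, 2.0, 2.0]
sample_variance(s)  # [2.0, 0.0, 2.0]
```

The update loop does not allocate, so the per-sample cost is one pass over the state vector. Whether this or the StructArrays approach fits the recorder better would need to be benchmarked.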