joshday / OnlineStats.jl

⚡ Single-pass algorithms for statistics
https://joshday.github.io/OnlineStats.jl/latest/
MIT License
831 stars 62 forks source link

_fit! on AutoCov is not type stable #269

Closed vandenman closed 9 months ago

vandenman commented 9 months ago

Here is an MWE:

import OnlineStats # v1.6.3
o = OnlineStats.AutoCov(5)
OnlineStats.fit!(o, 1.0)
@code_warntype OnlineStats._fit!(o, 1.0)

returns

MethodInstance for OnlineStatsBase._fit!(::OnlineStats.AutoCov{Float64, Float64}, ::Float64)
  from _fit!(o::OnlineStats.AutoCov{T}, y) where T @ OnlineStats ~/.julia/packages/OnlineStats/Kiiyv/src/stats/stats.jl:66
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(OnlineStatsBase._fit!)
  o::OnlineStats.AutoCov{Float64, Float64}
  y::Float64
Locals
  @_4::Union{Nothing, Tuple{Int64, Int64}}
  val@_5::Nothing
  val@_6::Any
  @_7::Union{Nothing, Tuple{Int64, Int64}}
  γ::Any
  val@_9::Float64
  k@_10::Int64
  lagk::Float64
  γk::Float64
  k@_13::Int64
  @_14::Float64
  @_15::Float64
Body::Nothing
1 ──       Core.NewvarNode(:(@_4))
│          Core.NewvarNode(:(val@_5))
│          Core.NewvarNode(:(val@_6))
│    %4  = Base.getproperty(o, :v)::OnlineStatsBase.Variance{Float64}
│    %5  = Base.getproperty(%4, :weight)::Any
│    %6  = Base.getproperty(o, :v)::OnlineStatsBase.Variance{Float64}
│    %7  = Base.getproperty(%6, :n)::Int64
│    %8  = (%7 + 1)::Int64
│          (γ = (%5)(%8))
│    %10 = Base.getproperty(o, :v)::OnlineStatsBase.Variance{Float64}
│          OnlineStats._fit!(%10, y)
│    %12 = Base.getproperty(o, :lag)::OnlineStatsBase.CircBuff{Float64}
│          OnlineStats._fit!(%12, y)
│    %14 = Base.getproperty(o, :wlag)::OnlineStatsBase.CircBuff{Float64}
│          OnlineStats._fit!(%14, γ)
│    %16 = Base.getproperty(o, :m2)::Vector{Float64}
│    %17 = OnlineStats.length(%16)::Int64
│    %18 = (2:%17)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(2), Int64])
│    %19 = OnlineStats.reverse(%18)::Core.PartialStruct(StepRange{Int64, Int64}, Any[Int64, Core.Const(-1), Int64])
│          (@_7 = Base.iterate(%19))
│    %21 = (@_7 === nothing)::Bool
│    %22 = Base.not_int(%21)::Bool
└───       goto #4 if not %22
2 ┄─ %24 = @_7::Tuple{Int64, Int64}
│          (k@_10 = Core.getfield(%24, 1))
│    %26 = Core.getfield(%24, 2)::Int64
│          nothing
│    %28 = Base.getproperty(o, :m1)::Vector{Float64}
│    %29 = (k@_10 - 1)::Int64
│    %30 = Base.getindex(%28, %29)::Float64
│    %31 = Base.getproperty(o, :m1)::Vector{Float64}
│          Base.setindex!(%31, %30, k@_10)
│          (val@_9 = %30)
│          nothing
│          val@_9
│          (@_7 = Base.iterate(%19, %26))
│    %37 = (@_7 === nothing)::Bool
│    %38 = Base.not_int(%37)::Bool
└───       goto #4 if not %38
3 ──       goto #2
4 ┄─       nothing
│    %42 = Base.getproperty(o, :m1)::Vector{Float64}
│    %43 = Base.getindex(%42, 1)::Float64
│    %44 = OnlineStats.smooth(%43, y, γ)::Any
│    %45 = Base.getproperty(o, :m1)::Vector{Float64}
│          Base.setindex!(%45, %44, 1)
│          (val@_6 = %44)
│          nothing
│          val@_6
│          nothing
│    %51 = Base.getproperty(o, :m1)::Vector{Float64}
│    %52 = OnlineStats.length(%51)::Int64
│    %53 = (1:%52)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_4 = Base.iterate(%53))
│    %55 = (@_4 === nothing)::Bool
│    %56 = Base.not_int(%55)::Bool
└───       goto #13 if not %56
5 ┄─       Core.NewvarNode(:(lagk))
│          Core.NewvarNode(:(γk))
│    %60 = @_4::Tuple{Int64, Int64}
│          (k@_13 = Core.getfield(%60, 1))
│    %62 = Core.getfield(%60, 2)::Int64
│    %63 = k@_13::Int64
│    %64 = Base.getproperty(o, :wlag)::OnlineStatsBase.CircBuff{Float64}
│    %65 = OnlineStats.length(%64)::Int64
│    %66 = (%63 ≤ %65)::Bool
└───       goto #7 if not %66
6 ── %68 = Base.getproperty(o, :wlag)::OnlineStatsBase.CircBuff{Float64}
│          (@_14 = Base.getindex(%68, k@_13))
└───       goto #8
7 ──       (@_14 = 0.0)
8 ┄─       (γk = @_14)
│    %73 = k@_13::Int64
│    %74 = Base.getproperty(o, :lag)::OnlineStatsBase.CircBuff{Float64}
│    %75 = OnlineStats.length(%74)::Int64
│    %76 = (%73 ≤ %75)::Bool
└───       goto #10 if not %76
9 ── %78 = Base.getproperty(o, :lag)::OnlineStatsBase.CircBuff{Float64}
│          (@_15 = Base.getindex(%78, k@_13))
└───       goto #11
10 ─       (@_15 = OnlineStats.zero($(Expr(:static_parameter, 1))))
11 ┄       (lagk = @_15)
│    %83 = Base.getproperty(o, :cross)::Vector{Float64}
│    %84 = Base.getindex(%83, k@_13)::Float64
│    %85 = (y * lagk)::Float64
│    %86 = OnlineStats.smooth(%84, %85, γk)::Float64
│    %87 = Base.getproperty(o, :cross)::Vector{Float64}
│          Base.setindex!(%87, %86, k@_13)
│    %89 = Base.getproperty(o, :m2)::Vector{Float64}
│    %90 = Base.getindex(%89, k@_13)::Float64
│    %91 = OnlineStats.smooth(%90, y, γk)::Float64
│    %92 = Base.getproperty(o, :m2)::Vector{Float64}
│          Base.setindex!(%92, %91, k@_13)
│          (@_4 = Base.iterate(%53, %62))
│    %95 = (@_4 === nothing)::Bool
│    %96 = Base.not_int(%95)::Bool
└───       goto #13 if not %96
12 ─       goto #5
13 ┄       (val@_5 = nothing)
│          nothing
└───       return val@_5

A call to @profview shows that the lines highlighted in red are type unstable (a screenshot was attached to the original issue).

Perhaps the fields of AutoCov should be fully typed? For example,

struct AutoCov{T, W, R} <: OnlineStat{Number}
    cross::Vector{Float64}
    m1::Vector{Float64}
    m2::Vector{Float64}
-   lag::CircBuff{T}
-   wlag::CircBuff{Float64}
-   v::Variance{W}
+   lag::CircBuff{T, true}
+   wlag::CircBuff{Float64, true}
+   v::Variance{T, T, W}
end

although I haven't tested this. If you think this is a reasonable solution, I'd be happy to contribute a PR!

joshday commented 9 months ago

> If you think this is a reasonable solution, I'd be happy to contribute a PR!

Yes, please make the PR!

Thanks for finding the issue. This has happened a few times when I've added parameters to a type and forgotten about some of the "downstream" types that contain it.