JuliaLang / LinearAlgebra.jl

Julia Linear Algebra standard library

`sum` over a `Diagonal(::Vector{<:SMatrix})` requires runtime dispatch #1016

Closed: jishnub closed this issue 11 months ago

jishnub commented 1 year ago

It would be great to avoid runtime dispatch in cases like this:

julia> using LinearAlgebra, StaticArrays, JET

julia> VERSION
v"1.10.0-alpha1"

julia> D = Diagonal(zeros(SMatrix{2,2,Int,4},2))
2×2 Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}:
 [0 0; 0 0]      ⋅     
     ⋅       [0 0; 0 0]

julia> @report_opt sum(D)
═════ 6 possible errors found ═════
┌ sum(a::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}) @ Base ./reducedim.jl:996
│┌ sum(a::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}; dims::Colon, kw::@Kwargs{}) @ Base ./reducedim.jl:996
││┌ _sum(A::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}, ::Colon) @ LinearAlgebra /cache/build/default-amdci5-1/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:889
│││┌ sum(a::Vector{SMatrix{2, 2, Int64, 4}}) @ Base ./reducedim.jl:996
││││┌ sum(a::Vector{SMatrix{2, 2, Int64, 4}}; dims::Colon, kw::Base.Pairs) @ Base ./reducedim.jl:996
│││││┌ merge(a::@NamedTuple{}, itr::Base.Pairs) @ Base ./namedtuple.jl:364
││││││┌ iterate(::Base.Pairs) @ Base.Iterators ./iterators.jl:301
│││││││┌ _pairs_elt(p::Base.Pairs, idx::Any) @ Base.Iterators ./iterators.jl:294
││││││││ runtime dispatch detected: (%4::Any)[idx::Any]::Any
│││││││└────────────────────
│││││┌ kwcall(::NamedTuple, ::typeof(Base._sum), a::Vector{SMatrix{2, 2, Int64, 4}}, ::Colon) @ Base ./reducedim.jl:1000
││││││┌ pairs(nt::NamedTuple) @ Base.Iterators ./iterators.jl:279
│││││││┌ (Base.Pairs{Symbol})(data::NamedTuple, itr::Tuple{Vararg{Symbol}}) @ Base ./essentials.jl:343
││││││││┌ eltype(::Type{A} where A<:NamedTuple) @ Base ./namedtuple.jl:237
│││││││││┌ nteltype(::Type{NamedTuple{names, T}} where names) where T<:Tuple @ Base ./namedtuple.jl:239
││││││││││┌ eltype(t::Type{<:Tuple{Vararg{E}}}) where E @ Base ./tuple.jl:208
│││││││││││┌ _compute_eltype(t::Type{<:Tuple{Vararg{E}}} where E) @ Base ./tuple.jl:231
││││││││││││┌ afoldl(op::Base.var"#54#55", a::Any, bs::Vararg{Any}) @ Base ./operators.jl:542
│││││││││││││┌ (::Base.var"#54#55")(a::Any, b::Any) @ Base ./tuple.jl:235
││││││││││││││┌ promote_typejoin(a::Any, b::Any) @ Base ./promotion.jl:172
│││││││││││││││┌ typejoin(a::Any, b::Any) @ Base ./promotion.jl:127
││││││││││││││││ runtime dispatch detected: Base.UnionAll(%403::Any, %405::Any)::Any
│││││││││││││││└────────────────────
││││││││││││┌ afoldl(op::Base.var"#54#55", a::Any, bs::Vararg{Any}) @ Base ./operators.jl:543
│││││││││││││┌ (::Base.var"#54#55")(a::Type, b::Any) @ Base ./tuple.jl:235
││││││││││││││┌ promote_typejoin(a::Type, b::Any) @ Base ./promotion.jl:172
│││││││││││││││┌ typejoin(a::Type, b::Any) @ Base ./promotion.jl:127
││││││││││││││││ runtime dispatch detected: Base.UnionAll(%398::Any, %400::Any)::Any
│││││││││││││││└────────────────────
│││┌ sum(a::Vector{SMatrix{2, 2, Int64, 4}}) @ Base ./reducedim.jl:996
││││ failed to optimize due to recursion: sum(::Vector{SMatrix{2, 2, Int64, 4}})
│││└────────────────────
││┌ _sum(A::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}, ::Colon) @ LinearAlgebra /cache/build/default-amdci5-1/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:889
│││ failed to optimize due to recursion: Base._sum(::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}, ::Colon)
││└────────────────────
│┌ sum(a::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}}; dims::Colon, kw::@Kwargs{}) @ Base ./reducedim.jl:996
││ failed to optimize due to recursion: Base.var"#sum#827"(::Colon, ::Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, ::typeof(sum), ::Diagonal{SMatrix{2, 2, Int64, 4}, Vector{SMatrix{2, 2, Int64, 4}}})
│└────────────────────

The issue also occurs on v1.9.2 and on nightly v"1.11.0-DEV.78".
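The trace shows that `sum(D)` is forwarded by `LinearAlgebra._sum(::Diagonal, ::Colon)` (diagonal.jl:889) to `sum` over the stored diagonal vector. JET reports the runtime dispatch inside that nested `sum(::Vector{SMatrix{2, 2, Int64, 4}})` call, alongside "failed to optimize due to recursion". A minimal sketch of the equivalence that the forwarding relies on is shown below. It assumes StaticArrays is installed, and the block values are arbitrary illustrative choices. Because the off-diagonal entries of a `Diagonal` are zero matrices, summing the whole `Diagonal` gives the same result as summing `D.diag`.

```julia
using LinearAlgebra
using StaticArrays  # provides SMatrix

# Two arbitrary 2×2 blocks on the diagonal (illustrative values only)
D = Diagonal([SMatrix{2,2,Int,4}(1, 2, 3, 4), SMatrix{2,2,Int,4}(5, 6, 7, 8)])

# Off-diagonal blocks of a Diagonal are zero matrices, so the full sum
# equals the sum of the stored diagonal vector `D.diag`.
@assert sum(D) == sum(D.diag)

# Calling `sum` on the Vector directly skips the Diagonal -> Vector nesting;
# whether that avoids the dispatch on a given Julia version is an assumption
# to check with `@report_opt sum(D.diag)`.
total = sum(D.diag)
println(total)
```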

jishnub commented 11 months ago

This is fixed on

julia> VERSION
v"1.11.0-DEV.1030"

julia> @report_opt sum(D)
No errors detected

I doubt this can be backported to v1.10, so I'll close the issue for now.