marius311 opened this issue 6 years ago.
julia> methods(Base.Broadcast.BroadcastStyle, Tuple{Type{<:Foo}})
# 12 methods for generic function "(::Type)":
[1] Base.Broadcast.BroadcastStyle(::Type{Union{}}) in Base.Broadcast at broadcast.jl:46
[2] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:Tuple) in Base.Broadcast at broadcast.jl:43
[3] Base.Broadcast.BroadcastStyle(::Type{T}) where T<:Union{Bidiagonal, Diagonal, LowerTriangular, SymTridiagonal, Tridiagonal, UnitLowerTriangular, UnitUpperTriangular, UpperTriangular} in LinearAlgebra at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/structuredbroadcast.jl:12
[4] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:SparseVector) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:40
[5] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:SparseMatrixCSC) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:41
[6] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Adjoint{T,#s607} where #s607<:Union{SparseMatrixCSC, SparseVector} where T)) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:69
[7] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Transpose{T,#s607} where #s607<:Union{SparseMatrixCSC, SparseVector} where T)) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:70
[8] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(RowVector{T,#s607} where #s607<:(Array{T,1} where T))) where T in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/deprecated.jl:214
[9] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:AbstractArray{T,N}) where {T, N} in Base.Broadcast at broadcast.jl:99
[10] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Base.Broadcast.Broadcasted{S,Axes,F,Args} where Args<:Tuple where F where Axes)) where S<:Union{Base.Broadcast.Unknown, Nothing} in Base.Broadcast at broadcast.jl:223
[11] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Base.Broadcast.Broadcasted{Style,Axes,F,Args} where Args<:Tuple where F where Axes)) where Style in Base.Broadcast at broadcast.jl:222
[12] Base.Broadcast.BroadcastStyle(::Type{#s2} where #s2<:Foo) in Main at REPL[76]:1
Type{Union{}} is the common intersection.
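A small sketch of why Type{Union{}} is the common intersection: Union{} is a subtype of every type, so the singleton type Type{Union{}} sits inside every Type{<:X}. That includes both Type{<:Foo} and the signatures of the other BroadcastStyle methods listed above. The Foo definition here is a hypothetical stand-in, because the issue's own Foo is not shown in the thread.

```julia
# Hypothetical stand-in for the issue's Foo (its real definition is not shown).
struct Foo{T}
    x::T
end

# Union{} <: Foo and Union{} <: Tuple, so Type{Union{}} is in both Type{<:...} sets.
println(Type{Union{}} <: Type{<:Foo})      # true
println(Type{Union{}} <: Type{<:Tuple})    # true

# A genuine Foo type matches the Foo method but not, e.g., the Tuple method.
println(Type{Foo{Int}} <: Type{<:Foo})     # true
println(Type{Foo{Int}} <: Type{<:Tuple})   # false
```

So a query for methods applicable to Type{<:Foo} can legitimately report every method whose signature also admits Type{Union{}}. That is also why Base defines an explicit BroadcastStyle(::Type{Union{}}) method for ambiguity resolution.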
I don't completely follow, sorry. Do you mean you think this is the correct behavior? Are you sure you didn't mean:
julia> methods(Base.Broadcast.BroadcastStyle, Tuple{Type{Foo}})
# 1 method for generic function "(::Type)":
[1] Base.Broadcast.BroadcastStyle(::Type{#s1} where #s1<:Foo) in Main at REPL[2]:1
which gives the correct one? I don't see how any of the methods above except [12] can be called on this type.
You assert that c isa Foo, so typeof(c) is inferred as Type{<:Foo} (Type{#s55} where #s55<:Foo{T} where T as the type of %1). Note that typeof(c) can never be Foo itself (so its type cannot be Type{Foo}), as Foo is not a leaf type. But Type{Union{}} is special-cased anyway, so something else is weird here. Also note:
julia> Base.return_types(Base.Broadcast.BroadcastStyle, Tuple{Type{<:Foo}})
12-element Array{Any,1}:
Base.Broadcast.Unknown
Base.Broadcast.Style{Tuple}
LinearAlgebra.StructuredMatrixStyle{Union{}}
SparseArrays.HigherOrderFns.SparseVecStyle
SparseArrays.HigherOrderFns.SparseMatStyle
SparseArrays.HigherOrderFns.PromoteToSparse
SparseArrays.HigherOrderFns.PromoteToSparse
Base.Broadcast.DefaultArrayStyle{2}
Base.Broadcast.DefaultArrayStyle{_1} where _1
Union{}
Any
Base.Broadcast.DefaultArrayStyle{0}
So only a subset of the methods is considered (probably due to the special-casing of Type{Union{}}), but at the same time LinearAlgebra.StructuredMatrixStyle{Union{}} becomes LinearAlgebra.StructuredMatrixStyle (instead of also being removed altogether). My wild guess is that something is suboptimal when inference determines the applicable methods in combination with the Union upper bound.
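To make the leaf-type point concrete, here is a minimal sketch with a hypothetical parametric Foo (the thread does not show the real one). At runtime, typeof(c) is always a concrete instantiation such as Foo{Int}, never the UnionAll Foo. Inference can therefore describe it as Type{<:Foo} but not as Type{Foo}.

```julia
# fresh session
# Hypothetical parametric Foo; the issue's actual definition is not shown.
struct Foo{T}
    x::T
end

c = Foo(1)

# The runtime type is a concrete instantiation, not the UnionAll Foo.
println(typeof(c) === Foo{Int})     # true
println(typeof(c) === Foo)          # false

# Foo alone is not a leaf (concrete) type; Foo{Int} is.
println(isconcretetype(Foo))        # false
println(isconcretetype(Foo{Int}))   # true

# Hence Type{<:Foo} describes typeof(c), while Type{Foo} (which holds only Foo) does not.
println(typeof(c) isa Type{<:Foo})  # true
println(typeof(c) isa Type{Foo})    # false
```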
Thanks, that makes slightly more sense. But maybe I don't actually understand what methods and return_types return. To me it seems that if c isa Foo, then the only method that could possibly be called in line 3 is [12], not any of the others, no? So why are the other 11 returned there?
Type{<:Foo} does not necessarily need to be obtained by typeof(x) where x isa Foo. And it could contain Type{Union{}}, which all those methods match. But I think that might be a red herring anyway. This reduction might be more approachable:
julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed}}) = Val(:union4)
foo (generic function with 1 method)
julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed, Unsigned}}) = Val(:union5)
foo (generic function with 2 methods)
julia> foo(::Type{<:Ref}) = Val(:ref)
foo (generic function with 3 methods)
julia> bar(x) = foo(typeof(x))
bar (generic function with 1 method)
julia> code_warntype(bar, Tuple{Ref})
Body::Union{Val{:union5}, Val{:ref}}
1 1 ─ %1 = Main.typeof(%%x)::Type{#s55} where #s55<:Ref{T} where T │
│ %2 = Main.foo(%1)::Union{Val{:union5}, Val{:ref}} │
└── return %2
So something breaks for a union with five or more members. Notably, if the correct method is defined before the offending one, things work as desired:
# fresh session
julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed}}) = Val(:union4)
foo (generic function with 1 method)
julia> foo(::Type{<:Ref}) = Val(:ref)
foo (generic function with 2 methods)
julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed, Unsigned}}) = Val(:union5)
foo (generic function with 3 methods)
julia> bar(x) = foo(typeof(x))
bar (generic function with 1 method)
julia> code_warntype(bar, Tuple{Ref})
Body::Val{:ref}
1 1 ─ return :($(QuoteNode(Val{:ref}())))
So there is something up with inference here.
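One way to confirm that this is purely an inference imprecision, not a dispatch bug, is to call bar on concrete values. The sketch below reuses the three methods from the first (method-order-sensitive) session. Runtime dispatch always picks the expected method, so the Val{:union5} branch that inference reports for Tuple{Ref} is never taken for an actual Ref. The printed return_types result is expected to vary with the Julia version and is not asserted here.

```julia
# fresh session; same methods and definition order as the first reduction
foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed}}) = Val(:union4)
foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed, Unsigned}}) = Val(:union5)
foo(::Type{<:Ref}) = Val(:ref)
bar(x) = foo(typeof(x))

# Runtime dispatch: each value hits the most specific matching method.
println(bar(Ref(1)) === Val(:ref))        # true
println(bar(1.0) === Val(:union4))        # true (the smaller union is more specific)
println(bar(UInt8(1)) === Val(:union5))   # true (Unsigned is only in the 5-member union)

# What inference predicts for an abstract Ref argument (version dependent).
println(Base.return_types(bar, Tuple{Ref}))
```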
That said, you're better off making sure the arguments to the broadcast can be inferred anyway 😛
Ah nice, that reduction is great, that must be what's happening. And thanks for sticking with me with the explanations.
> That said, you're better off making sure the arguments to the broadcast can be inferred anyway
I partially agree, but I do think there are scenarios where it's not worth complicating your code to keep things perfectly type-stable, as long as the performance hit is negligible. That was the case for me here in 0.6, largely because the result of this computation (Foo .+ Foo) fed into something for which inference was able to narrow the type down to something concrete, as long as the result was inferred as Foo and not as Any, as is happening here.
It's also happening for other types, btw:
julia> foo(x) = x .+ x
foo (generic function with 1 method)
julia> code_warntype(foo, (SparseVector,))
Body::Any
1 1 ─ %1 = :(Base.Broadcast.materialize)::Core.Compiler.Const(Base.Broadcast.materialize, false) │
│ %2 = :(Main.:+)::Core.Compiler.Const(+, false) │
│ %3 = Base.Broadcast.combine_styles(%%x, %%x)::Any │╻ broadcasted
│ %4 = Base.Broadcast.broadcasted(%3, %2, %%x, %%x)::Any ││
│ %5 = %1(%4)::Any │
└── return %5
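For comparison, a short sketch showing that the runtime result of the same foo on an actual SparseVector is concrete. Only the inferred type for the abstract argument type SparseVector is Any. The sparse vector values are arbitrary. The two printed return_types results (abstract vs. concrete argument type) depend on the Julia version and are left unasserted.

```julia
# fresh session
using SparseArrays

foo(x) = x .+ x

x = sparsevec([1, 3], [1.0, 2.0], 4)   # arbitrary example: length 4, nonzeros at 1 and 3
y = foo(x)

# The actual result is a concrete sparse vector with the expected values.
println(y isa SparseVector{Float64, Int})           # true
println(y == sparsevec([1, 3], [2.0, 4.0], 4))      # true

# Inference for the abstract vs. the concrete argument type (version dependent).
println(Base.return_types(foo, (SparseVector,)))
println(Base.return_types(foo, (typeof(x),)))
```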
Notably, with
diff --git a/base/broadcast.jl b/base/broadcast.jl
index 25b4c93415..2e44e366d8 100644
--- a/base/broadcast.jl
+++ b/base/broadcast.jl
@@ -43,7 +43,6 @@ struct Style{T} <: BroadcastStyle end
BroadcastStyle(::Type{<:Tuple}) = Style{Tuple}()
struct Unknown <: BroadcastStyle end
-BroadcastStyle(::Type{Union{}}) = Unknown() # ambiguity resolution
"""
`Broadcast.AbstractArrayStyle{N} <: BroadcastStyle` is the abstract supertype for any style
diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl
index fcd5c68d48..8fb94c9c2c 100644
--- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl
+++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl
@@ -8,8 +8,11 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()
-const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
-Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix} = StructuredMatrixStyle{T}()
+const StructuredMatrix1 = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}
+const StructuredMatrix2 = Union{LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
+const StructuredMatrix = Union{StructuredMatrix1,StructuredMatrix2}
+Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix1} = StructuredMatrixStyle{T}()
+Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix2} = StructuredMatrixStyle{T}()
# Promotion of broadcasts between structured matrices. This is slightly unusual
# as we define them symmetrically. This allows us to have a fallback to DefaultArrayStyle{2}().
one gets combine_styles inferred, but still:
julia> code_warntype(foo, (SparseVector,))
Body::Any
1 1 ─ %1 = :(Base.Broadcast.materialize)::Core.Compiler.Const(Base.Broadcast.materialize, false) │
│ %2 = :(Main.:+)::Core.Compiler.Const(+, false) │
│ %3 = Core.tuple(%%x, %%x)::Tuple{SparseArrays.SparseVector,SparseArrays.SparseVector} │╻ broadcasted
│ %4 = Base.Broadcast.Broadcasted{SparseArrays.HigherOrderFns.SparseVecStyle,Axes,F,Args} where Args<:Tuple where F where Axes(%2, %3)::Base.Broadcast.Broadcasted{SparseArrays.HigherOrderFns.SparseVecStyle,Nothing,typeof(+),_1} where _1
│ %5 = %1(%4)::Any │
└── return %5
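The diff above works around the problem by splitting one method over a large union into two methods over smaller unions. A hypothetical sketch of the same idea applied to the foo/bar reduction follows. NumUnion1 and NumUnion2 are illustrative names, and the union4 method is omitted so that runtime behavior matches a single five-member-union method. The runtime results are asserted. Whether inference then narrows bar to Val{:ref} for Tuple{Ref} is not verified here and may differ between Julia versions.

```julia
# fresh session
# Hypothetical split of the five-member union, analogous to StructuredMatrix1/StructuredMatrix2.
const NumUnion1 = Union{AbstractFloat, AbstractIrrational}
const NumUnion2 = Union{Rational, Signed, Unsigned}

foo(::Type{<:NumUnion1}) = Val(:union5)
foo(::Type{<:NumUnion2}) = Val(:union5)
foo(::Type{<:Ref}) = Val(:ref)
bar(x) = foo(typeof(x))

# Runtime dispatch is unchanged by the split.
println(bar(1.0) === Val(:union5))        # true
println(bar(1//2) === Val(:union5))       # true
println(bar(UInt8(1)) === Val(:union5))   # true
println(bar(Ref(1)) === Val(:ref))        # true

# Inspect what inference now reports (version dependent, not asserted).
println(Base.return_types(bar, Tuple{Ref}))
```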
I ran into this very simple inference failure affecting Broadcast.combine_styles, which happens when you try to broadcast over something that is inferred as a non-leaf type (in this case Foo instead of Foo{T}). This has the effect that e.g. Foo .+ Foo gets inferred as Any rather than Foo, as it should be (assuming the remaining pieces of code are implemented accordingly). I can say that my 0.6 version of this code worked this way (although without the new API it's longer and probably not worth writing out here).
Seeing StructuredMatrixStyle show up in the result was bizarre, so on a hunch I deleted line 12 from here: https://github.com/JuliaLang/julia/blob/f104ea4ec352519de76e52bf65ea7b3ed2dc6155/stdlib/LinearAlgebra/src/structuredbroadcast.jl#L10-L13 and this indeed made the inferred return type just Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown}. I'm guessing that the StructuredMatrix union there is just too large and something is getting screwed up?
As for why Base.Broadcast.Unknown shows up, that one is beyond me, although maybe that's OK because it happens for other types too. This is with 565bd4d265, by the way.
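For readers who want to reproduce the Foo .+ Foo case, here is a minimal sketch of a broadcastable Foo following the custom-broadcasting pattern from the Julia manual (a BroadcastStyle plus a similar method). This Foo is hypothetical: the issue's own definition is not shown, and Base.Broadcast.combine_styles is an internal function. The sketch asserts that broadcasting actually returns a Foo. The printed inference results, abstract Foo vs. concrete Foo{Int}, are what the issue is about and will depend on the Julia version.

```julia
# fresh session
# Hypothetical broadcastable Foo wrapping a Vector.
struct Foo{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(f::Foo) = size(f.data)
Base.getindex(f::Foo, i::Int) = f.data[i]
Base.setindex!(f::Foo, v, i::Int) = (f.data[i] = v)

# Opt into a custom broadcast style so results are allocated as Foo.
Base.BroadcastStyle(::Type{<:Foo}) = Broadcast.ArrayStyle{Foo}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Foo}}, ::Type{ElType}) where {ElType}
    return Foo(similar(Array{ElType}, axes(bc)))
end

foo(x) = x .+ x

r = foo(Foo([1, 2]))
println(r isa Foo{Int})   # true
println(r.data)

# Inference with a non-leaf (Foo) vs. a leaf (Foo{Int}) argument type.
println(Base.return_types(foo, Tuple{Foo}))
println(Base.return_types(foo, Tuple{Foo{Int}}))
println(Base.return_types(Base.Broadcast.combine_styles, Tuple{Foo, Foo}))
```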