JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License
45.09k stars 5.43k forks source link

Broadcast.combine_styles inference failure in 0.7 #28181

Open marius311 opened 6 years ago

marius311 commented 6 years ago

I ran into this very simple inference failure affecting Broadcast.combine_styles. It happens when you try to broadcast over something that is inferred as a non-leaf type (in this case Foo instead of Foo{T}).

julia> struct Foo{T} end

julia> Broadcast.BroadcastStyle(::Type{<:Foo}) = Broadcast.DefaultArrayStyle{0}()

julia> code_warntype(Broadcast.combine_styles, (Foo, ))
Body::Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown, LinearAlgebra.StructuredMatrixStyle{_1} where _1}
390 1 ─ %1  = Base.Broadcast.typeof(%%c)::Type{#s55} where #s55<:Foo{T} where T                                                 │
    │   %2  = isa(%1, Type{Union{}})::Bool                                                                                      │
    └──       goto 3 if not %2                                                                                                  │
    2 ─       goto 4                                                                                                            │
    3 ─ %5  = Base.Broadcast.BroadcastStyle(%1)::Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown, LinearAlgebra.StructuredMatrixStyle{_1} where _1}
    └──       goto 4                                                                                                            │
    4 ┄ %7  = φ (2 => :($(QuoteNode(Base.Broadcast.Unknown()))), 3 => %5)::Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown, LinearAlgebra.StructuredMatrixStyle{_1} where _1}
    │   %8  = isa(%7, Base.Broadcast.DefaultArrayStyle{0})::Bool                                                                │
    └──       goto 6 if not %8                                                                                                  │
    5 ─       goto 9                                                                                                            │
    6 ─ %11 = isa(%7, Base.Broadcast.Unknown)::Bool                                                                             │
    └──       goto 8 if not %11                                                                                                 │
    7 ─       goto 9                                                                                                            │
    8 ─ %14 = Base.Broadcast.result_style(%7)::Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown, LinearAlgebra.StructuredMatrixStyle{_1} where _1}
    └──       goto 9                                                                                                            │
    9 ┄ %16 = φ (5 => :($(QuoteNode(Base.Broadcast.DefaultArrayStyle{0}()))), 7 => :($(QuoteNode(Base.Broadcast.Unknown()))), 8 => %14)::Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown, LinearAlgebra.StructuredMatrixStyle{_1} where _1}
    └──       return %16                                                                                                        │

This has the effect that, e.g., Foo .+ Foo gets inferred as Any rather than as Foo, as it should be (assuming the remaining pieces of code are implemented accordingly). I can say that my 0.6 version of this code did infer Foo here (although without the new API it's longer and probably not worth writing out here).

Seeing StructuredMatrixStyle show up in the result was bizarre, so on a hunch I deleted line 12 from here: https://github.com/JuliaLang/julia/blob/f104ea4ec352519de76e52bf65ea7b3ed2dc6155/stdlib/LinearAlgebra/src/structuredbroadcast.jl#L10-L13 and this indeed made the inferred return type just Union{Base.Broadcast.DefaultArrayStyle{0}, Base.Broadcast.Unknown}. I'm guessing that the StructuredMatrix union there is just too large and something is getting screwed up?

As for why Base.Broadcast.Unknown shows up, that one is beyond me, although maybe that's OK because it happens for other types too. This is with commit 565bd4d265, by the way.

martinholters commented 6 years ago
julia> methods(Base.Broadcast.BroadcastStyle, Tuple{Type{<:Foo}})
# 12 methods for generic function "(::Type)":
[1] Base.Broadcast.BroadcastStyle(::Type{Union{}}) in Base.Broadcast at broadcast.jl:46
[2] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:Tuple) in Base.Broadcast at broadcast.jl:43
[3] Base.Broadcast.BroadcastStyle(::Type{T}) where T<:Union{Bidiagonal, Diagonal, LowerTriangular, SymTridiagonal, Tridiagonal, UnitLowerTriangular, UnitUpperTriangular, UpperTriangular} in LinearAlgebra at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/structuredbroadcast.jl:12
[4] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:SparseVector) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:40
[5] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:SparseMatrixCSC) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:41
[6] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Adjoint{T,#s607} where #s607<:Union{SparseMatrixCSC, SparseVector} where T)) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:69
[7] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Transpose{T,#s607} where #s607<:Union{SparseMatrixCSC, SparseVector} where T)) in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/higherorderfns.jl:70
[8] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(RowVector{T,#s607} where #s607<:(Array{T,1} where T))) where T in SparseArrays.HigherOrderFns at /home/hol/Projects/julia-master/usr/share/julia/stdlib/v0.7/SparseArrays/src/deprecated.jl:214
[9] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:AbstractArray{T,N}) where {T, N} in Base.Broadcast at broadcast.jl:99
[10] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Base.Broadcast.Broadcasted{S,Axes,F,Args} where Args<:Tuple where F where Axes)) where S<:Union{Base.Broadcast.Unknown, Nothing} in Base.Broadcast at broadcast.jl:223
[11] Base.Broadcast.BroadcastStyle(::Type{#s608} where #s608<:(Base.Broadcast.Broadcasted{Style,Axes,F,Args} where Args<:Tuple where F where Axes)) where Style in Base.Broadcast at broadcast.jl:222
[12] Base.Broadcast.BroadcastStyle(::Type{#s2} where #s2<:Foo) in Main at REPL[76]:1

Type{Union{}} is the common intersection of all these method signatures with Type{<:Foo}.

marius311 commented 6 years ago

I don't completely follow, sorry. Do you mean you think this is the correct behavior? Are you sure you didn't mean,

julia> methods(Base.Broadcast.BroadcastStyle, Tuple{Type{Foo}})
# 1 method for generic function "(::Type)":
[1] Base.Broadcast.BroadcastStyle(::Type{#s1} where #s1<:Foo) in Main at REPL[2]:1

which gives the correct one? I don't see how any of the methods above except [12] could be called on this type here.

martinholters commented 6 years ago

You only tell inference that c isa Foo, so typeof(c) is inferred as Type{<:Foo} (shown as Type{#s55} where #s55<:Foo{T} where T, the type of %1). Note that typeof(c) cannot be Type{Foo}, as Foo is not a leaf type. But Type{Union{}} is special-cased anyway, so something else is weird here. Also note:

julia> Base.return_types(Base.Broadcast.BroadcastStyle, Tuple{Type{<:Foo}})
12-element Array{Any,1}:
 Base.Broadcast.Unknown                       
 Base.Broadcast.Style{Tuple}                  
 LinearAlgebra.StructuredMatrixStyle{Union{}} 
 SparseArrays.HigherOrderFns.SparseVecStyle   
 SparseArrays.HigherOrderFns.SparseMatStyle   
 SparseArrays.HigherOrderFns.PromoteToSparse  
 SparseArrays.HigherOrderFns.PromoteToSparse  
 Base.Broadcast.DefaultArrayStyle{2}          
 Base.Broadcast.DefaultArrayStyle{_1} where _1
 Union{}                                      
 Any                                          
 Base.Broadcast.DefaultArrayStyle{0}

So only a subset of the methods is considered (probably due to the special-casing of Type{Union{}}), but at the same time LinearAlgebra.StructuredMatrixStyle{Union{}} is widened to LinearAlgebra.StructuredMatrixStyle instead of also being dropped altogether. My wild guess would be that something is sub-optimal when inference determines the applicable methods in combination with the Union upper bound.

marius311 commented 6 years ago

Thanks, that makes slightly more sense. But maybe I don't actually understand what methods and return_types report. It seems to me that if c isa Foo, then the only method that could possibly be called in block 3 is [12], not any of the others, no? So why are the other 11 returned there?

martinholters commented 6 years ago

A value of type Type{<:Foo} need not have been obtained from typeof(x) with x isa Foo. Also, Type{<:Foo} includes Type{Union{}} (since Union{} <: Foo), which all those methods match. But I think that might be a red herring anyway. This reduction might be more approachable:

julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed}}) = Val(:union4)
foo (generic function with 1 method)

julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed, Unsigned}}) = Val(:union5)
foo (generic function with 2 methods)

julia> foo(::Type{<:Ref}) = Val(:ref)
foo (generic function with 3 methods)

julia> bar(x) = foo(typeof(x))
bar (generic function with 1 method)

julia> code_warntype(bar, Tuple{Ref})
Body::Union{Val{:union5}, Val{:ref}}
1 1 ─ %1 = Main.typeof(%%x)::Type{#s55} where #s55<:Ref{T} where T                                                            │
  │   %2 = Main.foo(%1)::Union{Val{:union5}, Val{:ref}}                                                                       │
  └──      return %2 

So something breaks for a union with five or more members. Notably, if the correct method is defined before the offending one, things work as desired:

# fresh session
julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed}}) = Val(:union4)
foo (generic function with 1 method)

julia> foo(::Type{<:Ref}) = Val(:ref)
foo (generic function with 2 methods)

julia> foo(::Type{<:Union{AbstractFloat, AbstractIrrational, Rational, Signed, Unsigned}}) = Val(:union5)
foo (generic function with 3 methods)

julia> bar(x) = foo(typeof(x))
bar (generic function with 1 method)

julia> code_warntype(bar, Tuple{Ref})
Body::Val{:ref}
1 1 ─     return :($(QuoteNode(Val{:ref}())))

So there is something up with inference here.

That said, you're better off making sure the arguments to the broadcast can be inferred anyway 😛

marius311 commented 6 years ago

Ah nice, that reduction is great, that must be what's happening. And thanks for sticking with me with the explanations.

> That said, you're better off making sure the arguments to the broadcast can be inferred anyway

I partially agree, but I do think there are scenarios where it's not worth complicating your code to keep things perfectly type-stable, as long as the performance hit is negligible. That was the case for me here in 0.6. This was largely because the result of this computation (Foo .+ Foo) fed into something for which inference could narrow the type down to something concrete. That only worked as long as the result was inferred as Foo and not as Any, which is what is happening here.

martinholters commented 6 years ago

It's also happening for other types, btw:

julia> foo(x) = x .+ x
foo (generic function with 1 method)

julia> code_warntype(foo, (SparseVector,))
Body::Any
1 1 ─ %1 = :(Base.Broadcast.materialize)::Core.Compiler.Const(Base.Broadcast.materialize, false)                 │ 
  │   %2 = :(Main.:+)::Core.Compiler.Const(+, false)                                                             │ 
  │   %3 = Base.Broadcast.combine_styles(%%x, %%x)::Any                                                          │╻ broadcasted
  │   %4 = Base.Broadcast.broadcasted(%3, %2, %%x, %%x)::Any                                                     ││
  │   %5 = %1(%4)::Any                                                                                           │ 
  └──      return %5   

Notably, with

diff --git a/base/broadcast.jl b/base/broadcast.jl
index 25b4c93415..2e44e366d8 100644
--- a/base/broadcast.jl
+++ b/base/broadcast.jl
@@ -43,7 +43,6 @@ struct Style{T} <: BroadcastStyle end
 BroadcastStyle(::Type{<:Tuple}) = Style{Tuple}()

 struct Unknown <: BroadcastStyle end
-BroadcastStyle(::Type{Union{}}) = Unknown()  # ambiguity resolution

 """
 `Broadcast.AbstractArrayStyle{N} <: BroadcastStyle` is the abstract supertype for any style
diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl
index fcd5c68d48..8fb94c9c2c 100644
--- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl
+++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl
@@ -8,8 +8,11 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
 StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
 StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()

-const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
-Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix} = StructuredMatrixStyle{T}()
+const StructuredMatrix1 = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}
+const StructuredMatrix2 = Union{LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
+const StructuredMatrix = Union{StructuredMatrix1,StructuredMatrix2}
+Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix1} = StructuredMatrixStyle{T}()
+Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix2} = StructuredMatrixStyle{T}()

 # Promotion of broadcasts between structured matrices. This is slightly unusual
 # as we define them symmetrically. This allows us to have a fallback to DefaultArrayStyle{2}().

combine_styles is inferred correctly, but the overall result is still Any:

julia> code_warntype(foo, (SparseVector,))
Body::Any
1 1 ─ %1 = :(Base.Broadcast.materialize)::Core.Compiler.Const(Base.Broadcast.materialize, false)               │  
  │   %2 = :(Main.:+)::Core.Compiler.Const(+, false)                                                           │  
  │   %3 = Core.tuple(%%x, %%x)::Tuple{SparseArrays.SparseVector,SparseArrays.SparseVector}                    │╻  broadcasted
  │   %4 = Base.Broadcast.Broadcasted{SparseArrays.HigherOrderFns.SparseVecStyle,Axes,F,Args} where Args<:Tuple where F where Axes(%2, %3)::Base.Broadcast.Broadcasted{SparseArrays.HigherOrderFns.SparseVecStyle,Nothing,typeof(+),_1} where _1
  │   %5 = %1(%4)::Any                                                                                         │  
  └──      return %5