FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Type instability in Recur for 3 dimensional arrays #1947

Closed. characat0 closed this issue 2 years ago.

characat0 commented 2 years ago

While working with LSTM layers, I noticed that Recur is not type stable for 3-dimensional arrays.

If we execute:

using Flux
model = LSTM(1=>1)
x = rand(Float32, 1, 1, 1)
@code_warntype model(x)

Result:

MethodInstance for (::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}})(::Array{Float32, 3})
  from (m::Flux.Recur)(x::AbstractArray{T, 3}) where T in Flux at ~\.julia\packages\Flux\18YZE\src\layers\recurrent.jl:88
Static Parameters
  T = Float32
Arguments
  m::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}
  x::Array{Float32, 3}
Locals
  #269::Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}
  sze::Any
  h::Any
Body::Any
1 ─ %1  = %new(Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}, m)::Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}
│   %2  = invoke Base.var"#eachslice##kw"()($(QuoteNode((dims = 3,)))::NamedTuple{(:dims,), Tuple{Int64}}, Flux.eachslice::typeof(eachslice), x::Array{Float32, 3})::Base.Generator{Base.OneTo{Int64}}
│   %3  = Base.Generator(%1, %2)::Base.Generator{_A, Flux.var"#269#270"{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}} where _A
│   %4  = Base.collect(%3)::Any
│   %5  = Base.getindex(%4, 1)::Any
│   %6  = Flux.size(%5)::Any
│   %7  = Flux.reduce(Flux.hcat, %4)::Any
│   %8  = Base.getindex(%6, 1)::Any
│   %9  = Base.getindex(%6, 2)::Any
│   %10 = Flux.length(%4)::Any
│   %11 = Flux.reshape(%7, %8, %9, %10)::Any
└──       return %11

The type instability is caused by a call to eachslice in: https://github.com/FluxML/Flux.jl/blob/f038cffa5c2b8cd5c3a4f46e90d188a94096f72e/src/layers/recurrent.jl#L89

This is related to https://github.com/JuliaLang/julia/issues/39639. A quick fix would be to replace eachslice with a view of slices of the array along the third dimension.
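
Roughly what I have in mind (an untested sketch against the current 3-d Recur method, assuming plain dense arrays):

function (m::Flux.Recur)(x::AbstractArray{T, 3}) where T
  # slice the third dimension with explicit views instead of eachslice
  h = [m(view(x, :, :, t)) for t in 1:size(x, 3)]
  sze = size(h[1])
  return reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end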

ToucheSir commented 2 years ago

What do you mean by a "view of slices of the array along the third dimension"?

characat0 commented 2 years ago

> What do you mean by a "view of slices of the array along the third dimension"?

I was about to suggest going back to view(x, :, :, i) before I saw https://github.com/FluxML/Flux.jl/pull/1873, but now I am not sure if there is a way to make this type stable and performant.

mkschleg commented 2 years ago

Hmmm. Good catch. eachslice is effectively doing those views with an efficient rrule (as far as I understand it). It is a shame eachslice is not type stable. I can see two workarounds:

1. Do a type check after reduce and force typing on the return.
2. Do a custom eachslice that is type stable and re-implement the rrule for this new eachslice.

Option 1 is a fast hotfix, but option 2 is likely what would be best in the long term. I'm unsure how feasible option 2 is, though, and I'm swamped at work, so I've been spotty on my Flux progress and won't be able to help too much.
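
Something like this is what I mean by option 1 (untested, and the ::Array{T, 3} assertion assumes plain CPU Arrays; a GPU array type would need a different assertion):

function (m::Flux.Recur)(x::AbstractArray{T, 3}) where T
  h = [m(x_t) for x_t in eachslice(x, dims=3)]
  sze = size(h[1])
  # the body is still dynamically typed, but the assertion keeps the
  # instability from leaking into callers
  return reshape(reduce(hcat, h), sze[1], sze[2], length(h))::Array{T, 3}
end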

Given the problem with eachslice from the Julia issue you linked, we could implement a specialized eachslice that always slices along the last dimension, which we should be able to infer from the type signature. I think this would be type stable.
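
A rough sketch of option 2 (untested, and I haven't checked how well it infers or performs in practice):

# slice only along the last dimension, which is known from N in the type
function eachlastdim(x::AbstractArray{T, N}) where {T, N}
  inds = ntuple(_ -> Colon(), N - 1)
  return (view(x, inds..., i) for i in axes(x, N))
end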

characat0 commented 2 years ago

I found the rrule defined here in Zygote.jl; however, there is another definition in ChainRules.jl that takes advantage of Val to optimize it at compile time (and save a lot of allocations). I made a custom function called eachlastdim that returns an iterator over the last dimension and implemented the rrule for it, but I'm unsure whether I should use the Zygote or the ChainRules approach, since ChainRules is not directly imported in Flux.
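
The rrule is roughly along these lines (a simplified sketch, not the exact code, assuming an eachlastdim like the one sketched above):

using ChainRulesCore: ChainRulesCore, NoTangent, AbstractZero, unthunk

function ChainRulesCore.rrule(::typeof(eachlastdim), x::AbstractArray{T, N}) where {T, N}
  colons = ntuple(_ -> Colon(), N - 1)
  function eachlastdim_pullback(dys)
    # scatter each slice's cotangent back into a dense array of x's shape
    dx = zero(x)
    for (i, dy) in enumerate(unthunk(dys))
      dy isa AbstractZero && continue   # skip slices with no gradient
      view(dx, colons..., i) .+= dy
    end
    return (NoTangent(), dx)
  end
  return eachlastdim(x), eachlastdim_pullback
end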

ToucheSir commented 2 years ago

I wonder if deleting the Zygote adjoint would be enough; have you tried that?

characat0 commented 2 years ago

Currently I'm using the ChainRules one, so we could safely delete the definition in Zygote. Let me know what you think about the addition of eachlastdim and I will gladly make a PR in Zygote to delete the old adjoints.

ToucheSir commented 2 years ago

Ideally we wouldn't need #1948/eachlastdim at all if the compiler is smart enough to make sense of the rrule in ChainRules. That's why I mentioned testing after deleting the Zygote adjoint: Zygote will automatically fall back to that rrule and you can see whether type stability is preserved then.

characat0 commented 2 years ago

After deleting the rrules in Zygote, I tested the following code:

using Zygote
x = rand(Float32, 1, 1, 1, 10);
f(x) = eachslice(x; dims=4);
y, back = Zygote.pullback(f, x);
@code_warntype back(y)

Result:

MethodInstance for (::Zygote.var"#56#57"{typeof(∂(f))})(::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
  from (::Zygote.var"#56#57")(Δ) in Zygote at ~\.julia\packages\Zygote\xEPQb\src\compiler\interface.jl:41
Arguments
  #self#::Zygote.var"#56#57"{typeof(∂(f))}
  Δ::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Tuple{Array{Float32, 4}}
1 ─ %1 = Core.getfield(#self#, :back)::typeof(∂(f))
│   %2 = (%1)(Δ)::Tuple{Nothing, Array{Float32, 4}}
│   %3 = Zygote.tailmemaybe(%2)::Tuple{Array{Float32, 4}}
└──      return %3

So the pullback of eachslice is type stable when using the ChainRules rrule; however, the forward call to eachslice itself remains type unstable.
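
For reference, this is how I'm double-checking both directions (a sketch using the Test stdlib; @inferred throws when the inferred return type is abstract):

using Test, Zygote
x = rand(Float32, 1, 1, 1, 10);
f(x) = eachslice(x; dims=4);
y, back = Zygote.pullback(f, x);
@inferred back(y)   # passes: the pullback infers Tuple{Array{Float32, 4}}
# @inferred f(x)    # expected to throw, since the forward call is still unstable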

ToucheSir commented 2 years ago

Ok, it was worth a try! Let's continue this discussion in #1948.