Closed characat0 closed 2 years ago
What do you mean by a "view of slices of the array along the third dimension"?
I was about to suggest going back to view(x, :, :, i) before I saw https://github.com/FluxML/Flux.jl/pull/1873, but now I am not sure if there is a way to make this type stable and performant.
Hmmm. Good catch. eachslice is effectively doing those views with an efficient rrule (as far as I understand it). It is a shame eachslice is not type stable. I could see two workarounds.
1.) Do a type check after reduce and force typing on the return. 2.) Do a custom eachslice that is type stable and re-implement the rrule for this new eachslice.
Option 1 is a fast hot fix, but option 2 is likely what would be best in the long term. I'm unsure how feasible option 2 is, though, and I'm swamped at work, so I've been spotty on my Flux progress and won't be able to help too much.
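To make workaround 1 concrete, here is a rough sketch of asserting the return type after the reduce. The function name apply_over_time is hypothetical, and the sketch assumes plain CPU Arrays and an f that preserves the element type; it is not the code in Flux.

```julia
# Hypothetical helper illustrating workaround 1 (not Flux's implementation).
# Assumes f maps each 2-d slice to a Matrix with the same element type T.
function apply_over_time(f, x::AbstractArray{T,3}) where {T}
    h = [f(x_t) for x_t in eachslice(x; dims=3)]
    sze = size(h[1])
    out = reshape(reduce(hcat, h), sze[1], sze[2], length(h))
    # The assertion gives callers a concrete return type; it throws if the assumption is wrong.
    return out::Array{T,3}
end

x = rand(Float32, 2, 3, 5)
y = apply_over_time(x_t -> 2 .* x_t, x)
```

The assertion only hides the instability from callers. The body still goes through the unstable eachslice call, which is why this is a hot fix rather than a long-term solution.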
Given the problem in eachslice from the Julia issue you linked back to, we could implement a specialized eachslice that always iterates over the last dimension, which we should be able to infer from the type signature. I think this would be type-stable.
I found the rrule defined here in Zygote.jl, however there is another definition in ChainRules.jl which takes advantage of Val to optimize it at compile time (and save a lot of allocations).
I made a custom function called eachlastdim that returns an iterator over the last dimension and implemented the rrule for it, but I'm unsure if I should use the Zygote or ChainRules method since ChainRules is not directly imported in Flux.
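For readers following along, a minimal sketch of what such a function might look like is below. This is an illustration only, not necessarily the code from the PR, and the rrule is omitted. The idea is that the number of dimensions N is part of the array type, so the tuple of leading colons can be built at compile time.

```julia
# Illustrative sketch of an iterator over the last dimension (rrule not shown).
function eachlastdim(A::AbstractArray{T,N}) where {T,N}
    # (:, :, ..., :) with N-1 entries; N is known from the type signature.
    colons = ntuple(_ -> Colon(), Val(N - 1))
    return (view(A, colons..., i) for i in axes(A, N))
end

x = rand(Float32, 2, 3, 4)
slices = collect(eachlastdim(x))
```

Each element of slices is a non-copying view, so slices[2] refers to the same memory as x[:, :, 2].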
I wonder if deleting the Zygote adjoint would be enough, have you tried that?
Currently, I'm using the ChainRules one, so we could safely delete the definition in Zygote. Let me know what you think about the addition of eachlastdim and I will gladly make a PR in Zygote for deleting the old adjoints.
Ideally we wouldn't need #1948/eachlastdim at all if the compiler is smart enough to make sense of the rrule in ChainRules. That's why I mentioned testing after deleting the Zygote adjoint: Zygote will automatically fall back to that rrule and you can see whether type stability is preserved then.
After deleting the rrules in Zygote, I tested the following code:
using Zygote
x = rand(Float32, 1, 1, 1, 10);
f(x) = eachslice(x; dims=4);
y, back = Zygote.pullback(f, x);
@code_warntype back(y)
Result:
MethodInstance for (::Zygote.var"#56#57"{typeof(∂(f))})(::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
from (::Zygote.var"#56#57")(Δ) in Zygote at ~\.julia\packages\Zygote\xEPQb\src\compiler\interface.jl:41
Arguments
#self#::Zygote.var"#56#57"{typeof(∂(f))}
Δ::Vector{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Tuple{Array{Float32, 4}}
1 ─ %1 = Core.getfield(#self#, :back)::typeof(∂(f))
│ %2 = (%1)(Δ)::Tuple{Nothing, Array{Float32, 4}}
│ %3 = Zygote.tailmemaybe(%2)::Tuple{Array{Float32, 4}}
└── return %3
So the pullback of eachslice is type stable using the ChainRules rrule; however, eachslice itself remains type unstable.
Ok, it was worth a try! Let's continue this discussion in #1948.
While working with LSTM I noticed that Recur is not type stable for 3 dimensional arrays.
If we execute:
Result:
The type instability is caused by a call to eachslice in: https://github.com/FluxML/Flux.jl/blob/f038cffa5c2b8cd5c3a4f46e90d188a94096f72e/src/layers/recurrent.jl#L89
This is related to https://github.com/JuliaLang/julia/issues/39639. A quick fix would be to replace eachslice with a view of slices of the array along the third dimension.
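As a rough illustration of that quick fix, the sketch below builds the slices with explicit view calls, one per index of the third dimension. The array sizes and the (features, batch, time) layout are assumptions made for the example.

```julia
# Assumed layout for the example: (features, batch, time).
x = rand(Float32, 2, 3, 5)

# One non-copying 2-d view per index along the third dimension.
xs = [view(x, :, :, t) for t in axes(x, 3)]

# Views share memory with the parent array.
xs[1][1, 1] = 0f0
```

Because the dimension being sliced is hard-coded in the view call instead of passed as a runtime dims keyword, the element type of xs does not depend on a runtime value.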