FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

scalar getindex on gpu in the backwards pass #261

Closed jumerckx closed 5 years ago

jumerckx commented 5 years ago
using Zygote, CuArrays
CuArrays.allowscalar(false)  # make scalar indexing of GPU arrays an error

x, w = (cu(rand(10)), cu(rand(100,10)))

Zygote.gradient((w)->sum(w*x), w)  # fails in the backward pass

I'm not sure whether this is a Zygote or a CuArrays issue. When running this code, it fails because the matrix multiplication in the backward pass uses a generic matmul that calls scalar getindex.

Thanks
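For reference, the failing backward pass can be reproduced on the CPU with plain arrays. The sketch below hand-writes the pullback of W * x using the rule (Δ * x', A' * Δ) from the commented-out line of the adjoint quoted later in this thread. The function name matvec_pullback is hypothetical, and a dense ones vector stands in for the FillArrays.Fill that Zygote actually passes as the gradient of sum.

```julia
using LinearAlgebra

# Hypothetical hand-written pullback for y = W * x, using the rule
# (Δ * x', W' * Δ) from the matrix-vector adjoint discussed in this thread.
function matvec_pullback(W::AbstractMatrix, x::AbstractVector)
    y = W * x
    back(Δ::AbstractVector) = (Δ * x', W' * Δ)  # returns (∂W, ∂x)
    return y, back
end

W = rand(4, 3)
x = rand(3)
y, back = matvec_pullback(W, x)

# The gradient of sum(y) with respect to y is a vector of ones. Zygote
# passes it as a lazy FillArrays.Fill; a dense vector is used here instead.
Δ = ones(length(y))
∂W, ∂x = back(Δ)

# Every row of ∂W equals x, and ∂x = W' * Δ holds the column sums of W.
@assert all(∂W[i, :] ≈ x for i in axes(∂W, 1))
@assert ∂x ≈ vec(sum(W; dims=1))
```

The second component, W' * Δ, is the product that hits generic_matvecmul! once W is a CuArray and Δ is a Fill.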

kshyatt commented 5 years ago

I looked at this a little and it seems the output vector for the matmul is being allocated on the CPU. The specific error you get is:

ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at .../.julia/dev/GPUArrays/src/indexing.jl:14
 [3] getindex at ... /.julia/dev/GPUArrays/src/indexing.jl:54 [inlined]
 [4] generic_matvecmul!(::Array{Float32,1}, ::Char, ::CuArray{Float32,2}, ::FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}}, ::LinearAlgebra.MulAddMul{true,true,Bool,Bool}) at  ... julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:637
 [5] mul! at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:129 [inlined]
 [6] mul! at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:203 [inlined]
 [7] * at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:113 [inlined]
 [8] #1025 at .../.julia/dev/Zygote/src/lib/array.jl:230 [inlined]
 [9] #2554#back at .../.julia/dev/ZygoteRules/src/adjoint.jl:49 [inlined]
 [10] #3 at ./REPL[4]:1 [inlined]
 [11] (::Zygote.var"#28#29"{typeof(∂(#3))})(::Float32) at .../.julia/dev/Zygote/src/compiler/interface.jl:38
 [12] gradient(::Function, ::CuArray{Float32,2}) at .../.julia/dev/Zygote/src/compiler/interface.jl:47

Some more digging shows that A' * Δ is the culprit. I split the return up a bit to get:

typeof(A') = LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}
typeof(Δ) = FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}}
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at  .../.julia/dev/GPUArrays/src/indexing.jl:14
 [3] getindex at .../.julia/dev/GPUArrays/src/indexing.jl:54 [inlined]
 [4] generic_matvecmul!(::Array{Float32,1}, ::Char, ::CuArray{Float32,2}, ::FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}}, ::LinearAlgebra.MulAddMul{true,true,Bool,Bool}) at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:637

But:

julia> using CuArrays

julia> A = cu(rand(5,5)); x = cu(rand(5));

julia> typeof(A')
LinearAlgebra.Adjoint{Float32,CuArray{Float32,2}}

julia> CuArrays.allowscalar(false)

julia> A'*x
5-element CuArray{Float32,1}:
 1.2826176 
 1.1691043 
 1.4008925 
 1.5686983 
 0.83618677

works without a problem. If you do something very bad and change A' to A in the adjoint, things still break:

@adjoint function(A::AbstractMatrix * x::AbstractVector)
    A*x, function (Δ)
        resa = Δ * x'
        @show typeof(A)
        @show typeof(Δ)
        resb = A * Δ
        return (resa, resb)
    end
    # don't be like me, kids
    #return A * x, Δ::AbstractVector->(Δ * x', A' * Δ)
end

then:

julia> Zygote.gradient((w)->sum(w*x), w)
typeof(A) = CuArray{Float32,2}
typeof(Δ) = FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}}
ERROR: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at .../.julia/dev/GPUArrays/src/indexing.jl:14
 [3] getindex at .../.julia/dev/GPUArrays/src/indexing.jl:54 [inlined]
 [4] generic_matvecmul!(::Array{Float32,1}, ::Char, ::CuArray{Float32,2}, ::FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}}, ::LinearAlgebra.MulAddMul{true,true,Bool,Bool}) at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:651
 [5] mul! at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:81 [inlined]
 [6] mul! at ... julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:203 [inlined]
 [7] * at .../julia/usr/share/julia/stdlib/v1.3/LinearAlgebra/src/matmul.jl:51 [inlined]

It seems as if there are no matmul methods for multiplying a CuArray by a FillArrays.Fill?

kshyatt commented 5 years ago

OK, so I dug into this a bit more, and I think the problem is actually that FillArrays.jl doesn't specialize the matmul and matvecmul methods, so A::Matrix{Float32} * b::FillArrays.Fill{Float32,1,Tuple{Base.OneTo{Int64}}} also uses generic_matvecmul!. On the CPU this isn't a problem, but for the GPU it seems we should either define a specialized method or convert the Fill to a dense representation. @MikeInnes, thoughts?
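The two options can be compared on the CPU with plain arrays. This is a minimal sketch, not FillArrays' actual code: mul_by_constant is a hypothetical helper that takes the fill value c directly instead of a Fill object, and fill(c, n) plays the role of the dense conversion.

```julia
using LinearAlgebra

# Hypothetical specialized method: A * v, where v = fill(c, size(A, 2)) is a
# constant vector. Each output entry is c times a row sum of A, so the result
# is a reduction plus a scalar broadcast rather than a generic matvec loop.
mul_by_constant(A::AbstractMatrix, c::Number) = vec(sum(A; dims=2)) .* c

A = rand(5, 4)
c = 1.3

specialized = mul_by_constant(A, c)
densified   = A * fill(c, size(A, 2))  # the "convert to a dense representation" route

@assert specialized ≈ densified
```

Both sum(A; dims=2) and the scalar broadcast are whole-array operations, which is what makes this formulation a candidate for GPU arrays; later in the thread it is reported to work with CuArrays for an untransposed matrix.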

kshyatt commented 5 years ago

Sorry, one last thing -- changing the @adjoint to:

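@adjoint function(A::AbstractMatrix * x::AbstractVector)
    # Materialize Δ (here a FillArrays.Fill) as an array of the same type as x,
    # i.e. a CuArray on the GPU, so A' * Δ dispatches to the GPU matvec.
    return A * x, Δ::AbstractVector->(Δ * x', A' * convert(typeof(x), Array(Δ)))
end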

allows the gradient to be computed without erroring out, although it's ugly and probably pretty slow. The gradient is also currently broken if x is a CuMatrix, for the same reason.

With my modified gradient for matrix * vector, on CPU, I get:

julia> @time Zygote.gradient((w)->sum(w*x), w)
  0.037545 seconds (40.00 k allocations: 2.029 MiB)

with the original version falling back to generic_matvecmul!:

julia> @time Zygote.gradient((w)->sum(w*x), w)
  0.038051 seconds (39.08 k allocations: 1.987 MiB)

It could scale worse as the matrices and vectors get bigger, but for now it doesn't seem too bad? For a bigger example with the new method (still on CPU):

julia> x, w = (rand(10_000), rand(10_000,10_000));
julia> @time Zygote.gradient((w)->sum(w*x), w)
  0.551745 seconds (40.00 k allocations: 765.189 MiB)

For the fallback:

julia> @time Zygote.gradient((w)->sum(w*x), w)
  0.561201 seconds (39.07 k allocations: 765.072 MiB, 1.58% gc time)

So, still pretty similar...
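The CuMatrix case mentioned above has the same structure. For a loss of sum(W * X), the cotangent Δ is a constant m×k matrix, so every column of W' * Δ equals the (conjugated) column sums of W times the constant. Below is a CPU sketch with a hypothetical helper, adjoint_mul_by_constant_matrix, checked against the ordinary dense product; it illustrates the structure and is not necessarily what was eventually implemented.

```julia
using LinearAlgebra

# Hypothetical helper: W' * fill(c, m, k) for an m×n matrix W. Every column
# of the n×k result is c times the conjugated column sums of W, so no
# generic matmul loop over the constant matrix is required.
function adjoint_mul_by_constant_matrix(W::AbstractMatrix, c::Number, k::Integer)
    colsums = conj.(vec(sum(W; dims=1)))   # length n
    return repeat(colsums .* c, 1, k)      # n×k
end

W = rand(5, 3)
X = rand(3, 2)
Δ = ones(size(W, 1), size(X, 2))           # gradient of sum(W * X) w.r.t. W * X

@assert adjoint_mul_by_constant_matrix(W, 1.0, size(X, 2)) ≈ W' * Δ
```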

willtebbutt commented 5 years ago

@kshyatt it would be great if you could add AbstractMatrix * Fill and related methods to FillArrays -- this is definitely some structure that should be exploited and would be generally helpful. I'm not really sure why we haven't done this already to be honest...

kshyatt commented 5 years ago

@willtebbutt, is something like what I posted above an OK start? I'm worried that trying to do something fancier, by scaling and summing the rows of the matrix, will hit these getindex issues again.

willtebbutt commented 5 years ago

The above certainly looks like a good fix for now :) We probably want to avoid the extra temporary in the case that Δ is already a Vector, but other than that I can't see any problems.
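One way to skip the extra temporary is a type check before the conversion. The sketch below assumes a hypothetical helper named densify_like; in the adjoint it would replace the convert(typeof(x), Array(Δ)) call. A Base range stands in for a lazy Fill so the example runs without FillArrays or a GPU.

```julia
# Hypothetical helper: hand Δ back untouched when it already has the same
# concrete type as x, and only materialize it (as in the adjoint above) otherwise.
densify_like(x::AbstractVector, Δ::AbstractVector) =
    Δ isa typeof(x) ? Δ : convert(typeof(x), Array(Δ))

x = rand(3)

Δdense = rand(3)
@assert densify_like(x, Δdense) === Δdense   # same object, no temporary allocated

Δlazy = 1.0:3.0                             # a lazy AbstractVector standing in for a Fill
materialized = densify_like(x, Δlazy)
@assert materialized isa Vector{Float64}
@assert materialized == [1.0, 2.0, 3.0]
```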

It would be nice to solve this properly at some point though, as we're presumably suffering from the same problem in other pullbacks e.g. for AbstractMatrix * AbstractMatrix.

To be concrete, in the matrix-vector case it would be something like:

using CuArrays
CuArrays.allowscalar(false)
A = cu(randn(1000, 999))
x_value = 1.3
# A * Fill(x_value, 999) as a row sum followed by a scalar broadcast
reshape(sum(A; dims=2) .* x_value, 1000)

which is just composed of sum and broadcasting a scalar as you suggested. It appears to work as hoped. However, if you try something like sum(transpose(A); dims=2) indexing issues arise, also as you suggested.

kshyatt commented 5 years ago

Perhaps we can do the fast thing for untransposed and unadjointed arrays, and fall back to the dumb thing for transpose/adjoint, in the hope of coming up with a good solution later?

kshyatt commented 5 years ago

Actually, I think there is a way to do it quickly for transpose and adjoint too.
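A likely basis for a fast path: transpose(A) * fill(c, m) only needs the column sums of the parent matrix A, since entry j is c times the sum over i of A[i, j]; for the adjoint, the column sums are conjugated. Summing the parent along dims=1 never indexes the Transpose or Adjoint wrapper. The sketch below is a CPU check of that identity with hypothetical helper names; the thread does not say this is the exact approach that was implemented.

```julia
using LinearAlgebra

# Hypothetical helpers for transpose(A) * fill(c, m) and A' * fill(c, m),
# where A is m×n. Summing the parent A along dims=1 (column sums) replaces
# sum(transpose(A); dims=2), so the wrapper itself is never indexed.
transpose_mul_by_constant(A::AbstractMatrix, c::Number) = vec(sum(A; dims=1)) .* c
adjoint_mul_by_constant(A::AbstractMatrix, c::Number)   = conj.(vec(sum(A; dims=1))) .* c

A = rand(ComplexF64, 6, 4)
c = 2.0 + 0.5im
v = fill(c, size(A, 1))

@assert transpose_mul_by_constant(A, c) ≈ transpose(A) * v
@assert adjoint_mul_by_constant(A, c)   ≈ A' * v
```

A complex matrix is used so that the difference between transpose and adjoint (the conjugation) is actually exercised.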

kshyatt commented 5 years ago

@merckxiaan, can you confirm on your end that this is fixed? I think it should be, but I want to make sure.

jumerckx commented 5 years ago

Thanks a lot for having taken the time to look into this.

Chances are I'm doing something wrong, but I'm still seeing the error. I'm on Zygote v0.3.4 #master, ZygoteRules v0.2.0 #master, CuArrays v1.3.0 #master and IRTools v0.2.3 #master.

EDIT: When I try to add FillArrays#master, I get ERROR: Unsatisfiable requirements detected for package FillArrays [1a297f60]: ... due to Zygote

Here's my versioninfo():

Julia Version 1.1.0
Commit 80516ca202 (2019-01-21 21:24 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = "C:\Users\jules\AppData\Local\atom\app-1.40.1\atom.exe"  -a
  JULIA_NUM_THREADS = 4

DhairyaLGandhi commented 5 years ago

I was going to add some adjoints around CuArrays and matmul ops for coverage (and because we'd need them for Flux), specifically by specializing adjoints involving FillArrays and CuArrays to allocate on the GPU. Is this likely to make it a non-issue?

kshyatt commented 5 years ago

@merckxiaan, it might be the GPUArrays change we merged yesterday, which lets you use the latest FillArrays. Try checking out GPUArrays#master and FillArrays#master?

jumerckx commented 5 years ago

That did the trick! Huge thanks for this!!!

Should I close the issue?

kshyatt commented 5 years ago

If you feel it's been resolved, sure :)