JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

The `chunk` function is not differentiable on GPU #170

Closed: bicycle1885 closed this issue 8 months ago

bicycle1885 commented 10 months ago

I found that operations involving the chunk function are not differentiable on GPU.

using CUDA, Flux

struct Model
    layers
end

function Model()
    dense = Dense(3 => 8)
    Model((;dense))
end

Flux.@functor Model

function (model::Model)(x)
    y = model.layers.dense(x)
    a, b = Flux.chunk(y, size = [4, 4], dims = 1)
    sum(a + b)
end

model = Model()
x = randn(Float32, 3, 10)

x, model = gpu((x, model))
@show model(x)
Flux.withgradient(model -> model(x), model)

When I try to run this, I see the following error:

ERROR: LoadError: MethodError: no method matching parent(::Type{SubArray{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}, 0, Vector{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64}, true}})

Closest candidates are:
  parent(!Matched::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/adjtrans.jl:341
  parent(!Matched::Union{LinearAlgebra.LowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitLowerTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UnitUpperTriangular{T, S} where S<:AbstractMatrix{T}, LinearAlgebra.UpperTriangular{T, S} where S<:AbstractMatrix{T}} where T)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:164
  parent(!Matched::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/symmetric.jl:275
  ...

Full error message: log.txt

My environment is:

julia> versioninfo()
Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) Gold 6134 CPU @ 3.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake-avx512)
  Threads: 1 on 16 virtual cores
Environment:
  JULIA_PROJECT = @.

(tmp) pkg> st
Status `~/tmp/Project.toml`
  [052768ef] CUDA v5.0.0
  [587475ba] Flux v0.14.6
  [02a925ec] cuDNN v1.2.0

Manifest.toml and Project.toml are attached, renamed with a .txt extension so they could be uploaded: Manifest.txt, Project.txt

bicycle1885 commented 10 months ago

This might be caused by Zygote.jl 0.6.67: the problem goes away when I downgrade Zygote.jl to 0.6.66.
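One way to apply that downgrade until a fixed release is available is to pin the version with Pkg. This is a minimal sketch, assuming the `tmp` project from the report is the active environment; note that `Pkg.add` makes Zygote a direct dependency of that project rather than only an indirect one through Flux.

```julia
using Pkg

# Install the last Zygote.jl release reported in this thread to work.
Pkg.add(name = "Zygote", version = "0.6.66")

# Pin it so that a later Pkg.update() does not move it to 0.6.67.
Pkg.pin("Zygote")

# Once a fixed release is out, release the pin and update as usual:
# Pkg.free("Zygote"); Pkg.update("Zygote")
```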

mcabbott commented 10 months ago

This looks like a bug in the ChainRules (CR) rrule that is used here now that Zygote has deleted its own rule. Any chance you can isolate it further, e.g. to a single getindex call that gives a similar error?
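For context, a ChainRules rrule returns the primal value together with a pullback, and the getindex rule's pullback builds a tangent for the whole indexed container. The sketch below calls that rule directly on the CPU, assuming ChainRules.jl and ChainRulesCore.jl are installed; `v` is a hypothetical stand-in for the Vector being indexed. On GPU it is this pullback step (`∇getindex`) that fails in the reported stacktrace.

```julia
using ChainRulesCore   # rrule, unthunk
using ChainRules       # provides the rrule for getindex

# Hypothetical stand-in for the Vector of matrices being indexed.
v = [rand(Float32, 4, 2), rand(Float32, 4, 2)]

# Forward pass: the primal value v[1] plus a pullback closure.
y, getindex_pullback = rrule(getindex, v, 1)
@assert y === v[1]

# Backward pass: the cotangent for v is the second returned entry
# (the first is for the function itself, the third for the index).
dy = ones(Float32, 4, 2)
dv = unthunk(getindex_pullback(dy)[2])

# Only slot 1 receives dy; slot 2 is left as a zero tangent.
@assert length(dv) == 2
@assert dv[1] == dy
```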

bicycle1885 commented 10 months ago

I'm not sure what you mean by "to a single getindex call which gives a similar error". Could you elaborate? Then I'll test it soon.

mcabbott commented 10 months ago

Sorry, what I mean is that chunk must end up doing some indexing, maybe like x[:,1], which uses the getindex rule that shows up in the stacktrace. I think the error can probably be reproduced with something like gradient(x -> sum(abs2, x[:, 2:3]), cu(rand(2,3))), but I'm not sure which indices are needed.

This is the relevant bit of the stacktrace:

  [4] materialize!
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:46 [inlined]
  [5] materialize!
    @ ./broadcast.jl:881 [inlined]
  [6] ∇getindex!(dx::Vector{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:147
  [7] ∇getindex(x::Vector{SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:89

Here, `Vector{SubArray{..., CuArray...}}` means that something like x = collect(eachcol(cu(rand(2,3)))) is perhaps being indexed.
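The shape of that argument can be reproduced on the CPU. This sketch assumes, based only on the SubArray type in frame [7] and not on the MLUtils source, that `chunk(y, size = [4, 4], dims = 1)` returns views equivalent to `selectdim(y, 1, 1:4)` and `selectdim(y, 1, 5:8)`; `y` is a small hypothetical CPU matrix.

```julia
# Hypothetical CPU stand-in for the 8 × N output of Dense(3 => 8) (N = 2 here).
y = reshape(Float32.(1:16), 8, 2)

# Assumed equivalent of chunk(y, size = [4, 4], dims = 1): a Vector of row views.
# selectdim(y, 1, r) is view(y, r, :), with index types
# Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}} as in frame [7].
parts = [selectdim(y, 1, 1:4), selectdim(y, 1, 5:8)]
println(typeof(parts))

@assert parts isa Vector{<:SubArray}
@assert parent(parts[1]) === y          # views share memory with y
@assert parts[1] == y[1:4, :]

# `a, b = parts` then indexes this Vector with an Int, which is the
# getindex call whose pullback appears in the stacktrace.
a, b = parts
@assert b == y[5:8, :]
```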

bicycle1885 commented 10 months ago

Thanks. I've tested two other cases that index more directly, without calling chunk. One is a, b = y[1:4,:], y[5:8,:] and the other is a, b = view(y, 1:4, :), view(y, 5:8, :). I confirmed that both of them work on GPU without any error.

function (model::Model)(x)
    y = model.layers.dense(x)

    # ERROR: LoadError: MethodError: no method matching parent(...
    #a, b = Flux.chunk(y, size = [4, 4], dims = 1)

    # this works
    #a, b = y[1:4,:], y[5:8,:]

    # this works
    a, b = view(y, 1:4, :), view(y, 5:8, :)

    sum(a + b)
end

bicycle1885 commented 10 months ago

I discovered that the following pattern doesn't work. My guess is that lowering the destructuring assignment implicitly inserts getindex calls that break differentiation.

    a, b = [y[1:4,:], y[5:8,:]]
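That guess can be checked on the CPU without Zygote. A destructuring assignment such as `a, b = v` is lowered to calls of `Base.indexed_iterate`, and for a Vector each call amounts to indexing `v[i]`, i.e. a getindex on the Vector, which matches the `∇getindex(x::Vector{...}, dy, inds::Int64)` frame in the stacktrace. For a Tuple, `indexed_iterate` uses `getfield` instead. The sketch relies on Base internals, and the printed lowered code varies between Julia versions.

```julia
x = collect(Float32, 1:8)
v = [x[1:4], x[5:8]]   # a Vector, as in the failing pattern

# Print the lowered form of the destructuring assignment; it contains
# calls to Base.indexed_iterate (exact output depends on the Julia version).
println(Meta.lower(Main, :((a, b) = v)))

# Spell out those calls by hand: each one indexes the Vector.
a1, state = Base.indexed_iterate(v, 1)
b1, _ = Base.indexed_iterate(v, 2, state)
@assert a1 === v[1] && b1 === v[2]

# Ordinary destructuring gives the same two elements.
a, b = v
@assert a == x[1:4] && b == x[5:8]

# Destructuring a Tuple instead goes through getfield.
t = (x[1:4], x[5:8])
ta, _ = Base.indexed_iterate(t, 1)
@assert ta === getfield(t, 1)
```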

bicycle1885 commented 10 months ago

So, I reduced the code to the following. As in the case above, the error disappears if I use Zygote.jl 0.6.66 instead of 0.6.67.

using CUDA, Zygote

function f(x)
    a, b = [x[1:4], x[5:8]]
    sum(a + b)
end

x = cu(randn(8))
@show f(x)
@show Zygote.gradient(f, x)
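Because destructuring a Tuple does not go through getindex on a Vector, and the tuple form `a, b = y[1:4,:], y[5:8,:]` was reported above to work on GPU, one possible workaround is to collect the slices in a Tuple instead. This is a CPU sketch with a hypothetical `f_tuple`; it has not been checked here against Zygote 0.6.67 on GPU, where the input would be `cu(randn(8))` as in the reduced example.

```julia
using Zygote

# Same computation as f above, but the two slices are collected in a Tuple
# rather than a Vector, so `a, b = ...` destructures via getfield.
function f_tuple(x)
    a, b = (x[1:4], x[5:8])
    sum(a + b)
end

x = randn(Float32, 8)        # with CUDA.jl on a GPU: cu(randn(8))
g, = Zygote.gradient(f_tuple, x)

# Each element of x enters the sum exactly once, so the gradient is all ones.
@assert g ≈ ones(Float32, 8)
```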

mcabbott commented 10 months ago

Thanks, that's helpful!

ToucheSir commented 10 months ago

Thanks for the MWE, will follow up on the Zygote issue.

CarloLucibello commented 8 months ago

The example in the OP works fine with the latest versions of the packages.