EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

CUDA error in new Enzyme version #1930

Closed: jakubMitura14 closed this 2 months ago

jakubMitura14 commented 2 months ago

Hello, I loaded the latest version of Enzyme, tried the CUDA test case, and it gives an error.

GPU: RTX 3090
Enzyme version: 0.13.6
CUDA version: 5.5.2

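using CUDA, Enzyme, Test

# Square each element of A in place, one thread per element.
function mul_kernel(A)
  i = threadIdx().x
  if i <= length(A)
      A[i] *= A[i]
  end
  return nothing
end

# Reverse-mode gradient of mul_kernel, computed from inside a kernel.
# dA is the shadow of A: it is seeded with 1 before the call and
# should hold the gradient with respect to the input A afterwards.
function grad_mul_kernel(A, dA)
  autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA))
  return nothing
end

# Primal kernel alone
A = CUDA.ones(64,)
@cuda threads=length(A) mul_kernel(A)

# Gradient kernel: with A .== 1 and seed dA .== 1, each dA[i] should be 2
A = CUDA.ones(64,)
dA = similar(A)
dA .= 1
@cuda threads=length(A) grad_mul_kernel(A, dA)
all(dA .== 2)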

and it gives the following error:

julia> all(dA .== 2)
ERROR: a exception was thrown during kernel execution on thread (1, 1, 1) in block (1, 1, 1).
Stacktrace:
 [1] error at ./error.jl:35
 [2] autodiff_deferred at /usr/local/share/julia/packages/Enzyme/xD7hH/src/Enzyme.jl:692
 [3] grad_mul_kernel at ./REPL[3]:2

ERROR: KernelException: exception thrown during kernel execution on device NVIDIA GeForce RTX 3090
Stacktrace:
  [1] check_exceptions()
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/exceptions.jl:39
  [2] device_synchronize(; blocking::Bool, spin::Bool)
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/lib/cudadrv/synchronization.jl:191
  [3] device_synchronize
    @ /usr/local/share/julia/packages/CUDA/2kjXI/lib/cudadrv/synchronization.jl:178 [inlined]
  [4] checked_cuModuleLoadDataEx(_module::Base.RefValue{…}, image::Ptr{…}, numOptions::Int64, options::Vector{…}, optionValues::Vector{…})
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/lib/cudadrv/module.jl:18
  [5] CuModule(data::Vector{UInt8}, options::Dict{CUDA.CUjit_option_enum, Any})
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/lib/cudadrv/module.jl:60
  [6] CuModule
    @ /usr/local/share/julia/packages/CUDA/2kjXI/lib/cudadrv/module.jl:49 [inlined]
  [7] link(job::GPUCompiler.CompilerJob, compiled::@NamedTuple{image::Vector{UInt8}, entry::String})
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/compilation.jl:409
  [8] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/execution.jl:262
  [9] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler /usr/local/share/julia/packages/GPUCompiler/2CW9L/src/execution.jl:151
 [10] macro expansion
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:380 [inlined]
 [11] macro expansion
    @ ./lock.jl:267 [inlined]
 [12] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:375
 [13] cufunction
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:372 [inlined]
 [14] macro expansion
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/compiler/execution.jl:112 [inlined]
 [15] #launch_heuristic#1200
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/gpuarrays.jl:17 [inlined]
 [16] launch_heuristic
    @ /usr/local/share/julia/packages/CUDA/2kjXI/src/gpuarrays.jl:15 [inlined]
 [17] _copyto!
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:78 [inlined]
 [18] copyto!
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:44 [inlined]
 [19] copy
    @ /usr/local/share/julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [20] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Nothing, typeof(==), Tuple{CuArray{…}, Int64}})
    @ Base.Broadcast ./broadcast.jl:903
 [21] top-level scope
    @ REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.
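For reference, the final check `all(dA .== 2)` in the test case encodes the expected reverse-mode result. The following is a short worked derivation, assuming the usual reverse-mode rule for an in-place update, where the shadow entry `dA[i]` holds the output adjoint (the seed) on entry and is replaced by the input adjoint on exit. Here $a_i$ denotes `A[i]` before the kernel runs and $\bar a_i$ its adjoint; these symbols are introduced only for this sketch.

```latex
a_i^{\text{new}} = a_i^{2}
\quad\Longrightarrow\quad
\bar a_i = \frac{\partial a_i^{\text{new}}}{\partial a_i}\,\bar a_i^{\text{new}} = 2\,a_i\,\bar a_i^{\text{new}},
\qquad
a_i = 1,\ \bar a_i^{\text{new}} = 1 \ \Longrightarrow\ \bar a_i = 2 .
```

Here $a_i = 1$ comes from `CUDA.ones(64,)` and the seed $\bar a_i^{\text{new}} = 1$ from `dA .= 1`, so a successful gradient kernel should leave every entry of `dA` equal to 2. In the run above the assertion is never reached, because the device-side exception from `autodiff_deferred` is reported first.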
wsmoses commented 2 months ago

Fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1931