LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

batched_jacobian + CUDA => InvalidIRError #636

Closed. aksuhton closed this issue 4 months ago.

aksuhton commented 4 months ago

I'm getting an InvalidIRError when using batched_jacobian with CUDA if the feature array takes on certain sizes. Below are an MWE and the stacktrace. Thank you again for working with me so much on Jacobians!

MWE:

##
using Pkg
Pkg.add(["ADTypes", "Zygote", "ForwardDiff"])
Pkg.add(["Random", "LinearAlgebra"])
Pkg.add(["Lux", "LuxCUDA"])
##
using Lux, LuxCUDA
using ADTypes, Zygote, ForwardDiff
using Random, LinearAlgebra
##
function test_forward(N::Int; b::Int = 3, dev = gpu_device())
    model = @compact(; potential=Dense(N => N, gelu)) do x
        jac_pot = batched_jacobian(potential, AutoForwardDiff(), x)
        return jac_pot, potential.st
    end
    ps, st = Lux.setup(Random.default_rng(), model) .|> dev
    x = randn(Float32, N, b) |> dev
    m_x, st_ = model(x, ps, st);
end
##
test_forward(11)
# passes for N in 1:10, fails for N = 11;
# passes for N = 50, fails for N = 51
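
For reference, here is a minimal CPU sanity check. Since the InvalidIRError below is raised while compiling a CUDA kernel, the same call should go through on the CPU (my assumption, not verified in this report); cpu_device() is the Lux-exported counterpart of gpu_device():

##
# Sketch: rerun the MWE on the CPU to isolate the failure to GPU codegen.
test_forward(11; dev = cpu_device())  # expected to pass; the error is GPU-only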

Stacktrace:

1-element ExceptionStack:
LoadError: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#34#36")(::CUDA.CuKernelContext, ::CuDeviceVector{ForwardDiff.Partials{11, Float32}, 1}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, LuxForwardDiffExt.var"#30#37"{11, ForwardDiff.Partials{11, Float32}, Float32}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Int64, 1}, Tuple{Bool}, Tuple{Int64}}}}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [2] _broadcast_getindex
   @ ./broadcast.jl:682
 [3] getindex
   @ ./broadcast.jl:636
 [4] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [2] _broadcast_getindex
   @ ./broadcast.jl:682
 [3] getindex
   @ ./broadcast.jl:636
 [4] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [2] _broadcast_getindex
   @ ./broadcast.jl:682
 [3] getindex
   @ ./broadcast.jl:636
 [4] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported dynamic function invocation (call to ForwardDiff.Partials{11, Float32})
Stacktrace:
 [1] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [2] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [3] _broadcast_getindex
   @ ./broadcast.jl:682
 [4] getindex
   @ ./broadcast.jl:636
 [5] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [2] _broadcast_getindex
   @ ./broadcast.jl:682
 [3] getindex
   @ ./broadcast.jl:636
 [4] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Stacktrace:
 [1] ntuple
   @ ./ntuple.jl:19
 [2] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [4] _broadcast_getindex
   @ ./broadcast.jl:682
 [5] getindex
   @ ./broadcast.jl:636
 [6] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Stacktrace:
 [1] ntuple
   @ ./ntuple.jl:19
 [2] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [4] _broadcast_getindex
   @ ./broadcast.jl:682
 [5] getindex
   @ ./broadcast.jl:636
 [6] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_1d)
Stacktrace:
  [1] Array
    @ ./boot.jl:477
  [2] Array
    @ ./boot.jl:486
  [3] similar
    @ ./abstractarray.jl:877
  [4] similar
    @ ./abstractarray.jl:876
  [5] _array_for
    @ ./array.jl:723
  [6] collect
    @ ./array.jl:836
  [7] _ntuple
    @ ./ntuple.jl:37
  [8] ntuple
    @ ./ntuple.jl:19
  [9] #30
    @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [10] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
 [11] _broadcast_getindex
    @ ./broadcast.jl:682
 [12] getindex
    @ ./broadcast.jl:636
 [13] #34
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call through a literal pointer (call to ijl_alloc_array_1d)
Stacktrace:
  [1] Array
    @ ./boot.jl:477
  [2] Array
    @ ./boot.jl:486
  [3] similar
    @ ./abstractarray.jl:877
  [4] similar
    @ ./abstractarray.jl:876
  [5] _array_for
    @ ./array.jl:723
  [6] collect
    @ ./array.jl:839
  [7] _ntuple
    @ ./ntuple.jl:37
  [8] ntuple
    @ ./ntuple.jl:19
  [9] #30
    @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [10] _broadcast_getindex_evalf
    @ ./broadcast.jl:709
 [11] _broadcast_getindex
    @ ./broadcast.jl:682
 [12] getindex
    @ ./broadcast.jl:636
 [13] #34
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Stacktrace:
 [1] ntuple
   @ ./ntuple.jl:19
 [2] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [4] _broadcast_getindex
   @ ./broadcast.jl:682
 [5] getindex
   @ ./broadcast.jl:636
 [6] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to jl_f__apply_iterate)
Stacktrace:
 [1] _ntuple
   @ ./ntuple.jl:37
 [2] ntuple
   @ ./ntuple.jl:19
 [3] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [4] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [5] _broadcast_getindex
   @ ./broadcast.jl:682
 [6] getindex
   @ ./broadcast.jl:636
 [7] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Stacktrace:
 [1] ntuple
   @ ./ntuple.jl:19
 [2] #30
   @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:709
 [4] _broadcast_getindex
   @ ./broadcast.jl:682
 [5] getindex
   @ ./broadcast.jl:636
 [6] #34
   @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:59
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/validation.jl:147
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:445 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:444 [inlined]
  [5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:92
  [6] emit_llvm
    @ ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:86 [inlined]
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:134
  [8] codegen
    @ ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:115 [inlined]
  [9] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:111
 [10] compile
    @ ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:103 [inlined]
 [11] #1116
    @ ~/.julia/packages/CUDA/jdJ7Z/src/compiler/compilation.jl:247 [inlined]
 [12] JuliaContext(f::CUDA.var"#1116#1119"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [13] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [14] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/jdJ7Z/src/compiler/compilation.jl:246
 [15] actual_compilation(cache::Dict{Any, CuFunction}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/execution.jl:128
 [16] cached_compilation(cache::Dict{Any, CuFunction}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/execution.jl:103
 [17] macro expansion
    @ ~/.julia/packages/CUDA/jdJ7Z/src/compiler/execution.jl:367 [inlined]
 [18] macro expansion
    @ ./lock.jl:267 [inlined]
 [19] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceVector{ForwardDiff.Partials{11, Float32}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, LuxForwardDiffExt.var"#30#37"{11, ForwardDiff.Partials{11, Float32}, Float32}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Int64, 1}, Tuple{Bool}, Tuple{Int64}}}}, Int64}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/jdJ7Z/src/compiler/execution.jl:362
 [20] cufunction
    @ ~/.julia/packages/CUDA/jdJ7Z/src/compiler/execution.jl:359 [inlined]
 [21] macro expansion
    @ ~/.julia/packages/CUDA/jdJ7Z/src/compiler/execution.jl:112 [inlined]
 [22] #launch_heuristic#1173
    @ ~/.julia/packages/CUDA/jdJ7Z/src/gpuarrays.jl:17 [inlined]
 [23] launch_heuristic
    @ ~/.julia/packages/CUDA/jdJ7Z/src/gpuarrays.jl:15 [inlined]
 [24] _copyto!
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:78 [inlined]
 [25] copyto!
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:44 [inlined]
 [26] copy
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:29 [inlined]
 [27] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, LuxForwardDiffExt.var"#30#37"{11, ForwardDiff.Partials{11, Float32}, Float32}, Tuple{CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}}})
    @ Base.Broadcast ./broadcast.jl:903
 [28] map(::Function, ::CuArray{Int64, 1, CUDA.Mem.DeviceBuffer})
    @ GPUArrays ~/.julia/packages/GPUArrays/OKkAu/src/host/broadcast.jl:102
 [29] __batched_forwarddiff_jacobian_chunk!!(J_partial::Nothing, f::LuxForwardDiffExt.var"#28#29"{Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64, Int64}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Type{ForwardDiff.Tag{LuxForwardDiffExt.var"#28#29"{Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64, Int64}}, Float32}}, ::ForwardDiff.Chunk{11}, ::Type{ForwardDiff.Dual{ForwardDiff.Tag{LuxForwardDiffExt.var"#28#29"{Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64, Int64}}, Float32}, Float32, 11}}, ::Type{ForwardDiff.Partials{11, Float32}}, idx::Int64)
    @ LuxForwardDiffExt ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:99
 [30] __batched_forwarddiff_jacobian(f::LuxForwardDiffExt.var"#28#29"{Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64, Int64}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Type{ForwardDiff.Tag{LuxForwardDiffExt.var"#28#29"{Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Int64, Int64}}, Float32}}, ck::ForwardDiff.Chunk{11})
    @ LuxForwardDiffExt ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:63
 [31] __batched_jacobian_impl(f::Base.Fix2{StatefulLuxLayer{true, Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, @NamedTuple{}}, @NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, backend::AutoForwardDiff{nothing, Nothing}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LuxForwardDiffExt ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:48
 [32] __batched_jacobian
    @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:3 [inlined]
 [33] __batched_jacobian
    @ ~/.julia/packages/Lux/SBzKQ/ext/LuxForwardDiffExt/batched_ad.jl:37 [inlined]
 [34] batched_jacobian
    @ ~/.julia/packages/Lux/SBzKQ/src/helpers/autodiff.jl:122 [inlined]
 [35] (::var"#12#14")(self#3262::@NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ps#3263::@NamedTuple{potential::@NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, st#3264::@NamedTuple{potential::@NamedTuple{}})
    @ Main ./none:0
 [36] (::CompactLuxLayer{nothing, var"#12#14", @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, @NamedTuple{potential::Dense{true, typeof(gelu), typeof(glorot_uniform), typeof(zeros32)}}, Lux.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}})(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ps::@NamedTuple{potential::@NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, st::@NamedTuple{potential::@NamedTuple{}})
    @ Lux ~/.julia/packages/Lux/SBzKQ/src/helpers/compact.jl:422
 [37] test_forward(N::Int64; b::Int64)
    @ Main ~/GitHub/G2TWIN/dev/test.jl:19
 [38] test_forward(N::Int64)
    @ Main ~/GitHub/G2TWIN/dev/test.jl:11
 [39] top-level scope
    @ string:1
 [40] eval
    @ ./boot.jl:385 [inlined]
 [41] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:2076
 [42] include_string
    @ ./loading.jl:2086 [inlined]
 [43] include_string(m::Module, txt::String)
    @ Base ./loading.jl:2086
 [44] top-level scope
    @ REPL[13]:1
 [45] top-level scope
    @ ~/.julia/packages/CUDA/jdJ7Z/src/initialization.jl:206
in expression starting at string:1
avik-pal commented 4 months ago

That PR will fix it.
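
Until that fix lands, a possible stopgap, sketched here under the assumption that AutoZygote (from ADTypes, already loaded in the MWE) is an accepted batched_jacobian backend, is to swap backends and sidestep the failing ForwardDiff kernel:

model = @compact(; potential=Dense(N => N, gelu)) do x
    # Zygote-based batched Jacobian; avoids the LuxForwardDiffExt GPU kernel
    jac_pot = batched_jacobian(potential, AutoZygote(), x)
    return jac_pot
end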

One side note:

model = @compact(; potential=Dense(N => N, gelu)) do x
    jac_pot = batched_jacobian(potential, AutoForwardDiff(), x)
    return jac_pot, potential.st
end

should be

model = @compact(; potential=Dense(N => N, gelu)) do x
    jac_pot = batched_jacobian(potential, AutoForwardDiff(), x)
    return jac_pot
end
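
For completeness, a corrected version of the MWE's test function with just that change (a sketch; @compact manages sublayer state itself, so returning potential.st by hand should be unnecessary):

function test_forward(N::Int; b::Int = 3, dev = gpu_device())
    model = @compact(; potential = Dense(N => N, gelu)) do x
        jac_pot = batched_jacobian(potential, AutoForwardDiff(), x)
        return jac_pot  # state is handled by @compact
    end
    ps, st = Lux.setup(Random.default_rng(), model) .|> dev
    x = randn(Float32, N, b) |> dev
    return first(model(x, ps, st))
end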
aksuhton commented 4 months ago

Thank you for the fix and the tip!