Broadcasting a function returning an anonymous function with a constructor over CUDA arrays fails to compile, "not isbits" #2514

BioTurboNick closed 3 days ago

BioTurboNick commented 1 week ago

Describe the bug

When broadcasting a closure that was returned by another function (here, one wrapping a type constructor) over CUDA arrays, the broadcast fails to compile with the following error:

ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#34#36")(::CUDA.CuKernelContext, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, var"#3#4"{Type{Bar}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .f is of type var"#3#4"{Type{Bar}} which is not isbits.
    .f is of type Type{Bar} which is not isbits.

  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\validation.jl:92
  [2] macro expansion
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:123 [inlined]
  [3] macro expansion
    @ C:\Users\nicho\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:121
  [5] codegen
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:110 [inlined]
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:106
  [7] compile
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:98 [inlined]
  [8] #1072
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\compilation.jl:247 [inlined]
  [9] JuliaContext(f::CUDA.var"#1072#1075"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:47
 [10] compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\compilation.jl:246
 [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\execution.jl:125
 [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\execution.jl:103
 [13] macro expansion
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:367 [inlined]
 [14] macro expansion
    @ .\lock.jl:267 [inlined]
 [15] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:362
 [16] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}})
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:359
 [17] macro expansion
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:112 [inlined]
 [18] #launch_heuristic#1122
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\gpuarrays.jl:17 [inlined]
 [19] launch_heuristic
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\gpuarrays.jl:15 [inlined]
 [20] _copyto!
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:78 [inlined]
 [21] copyto!
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:44 [inlined]
 [22] copy
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:29 [inlined]
 [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, var"#3#4"{…}, Tuple{…}})
    @ Base.Broadcast .\broadcast.jl:903
 [24] top-level scope
    @ REPL[11]:1
 [25] top-level scope
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\initialization.jl:206
Some type information was truncated. Use `show(err)` to see complete types.

To reproduce

The Minimal Working Example (MWE) for this bug:

using CUDA

struct Bar{T}
    x::T  # field names assumed; the calls below pass two values of the same type
    y::T
end

foo(f) = (args...) -> f(args...)

a = cu(zeros(5)); b = cu(ones(5)); c = Bar; d = foo(c)

c.(a, b) # works, produces GPU array

foo(c).(collect(a), collect(b)) # works, produces CPU array

((args...) -> Bar(args...)).(a, b) # works, produces GPU array

foo(c).(a, b) # fails

Expected behavior

c.(a, b) == foo(c).(a, b)

Version info

Details on Julia:

Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)

Details on CUDA:

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.94.0

CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.94

Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

1 device:
  0: NVIDIA GeForce RTX 4080 (sm_89, 13.307 GiB / 15.992 GiB available)

Additional context

Encountered when Zygote differentiates over a broadcast of the form Bar.(a, b) via broadcast_forward, where a and b are GPU arrays.

maleadt commented 3 days ago

Further reduced MWE:

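using CUDA
struct Bar{T}
    x::T  # field names assumed
    y::T
end

function main()
    a = cu(zeros(5))

    capture = Bar
    function closure(arg)
        capture(arg, arg)  # body missing in the source; assumed minimal call
        return
    end
    function kernel(f, x)
        f(x[1])            # body missing in the source; assumed minimal call
        return
    end
    @cuda kernel(closure, a)
end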

So the problem is that you're passing a closure (closure in my MWE, the result of foo(c) in yours) which captures a type-unstable variable (capture in mine, c in yours; both are ::Type{Bar} with unbound type vars). Because that type is contained in a closure, it doesn't pass Core.Compiler.isconstType, so we don't filter it out during validation.
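
The non-isbits capture described above can be observed on the CPU, without launching a kernel. The following is a minimal sketch rather than code from the report: the two-field layout of Bar and the field names x and y are assumptions, and the names wrapped and literal are only illustrative.

```julia
# CPU-only sketch; no CUDA needed. Field names are assumed for illustration.
struct Bar{T}
    x::T
    y::T
end

foo(f) = (args...) -> f(args...)

wrapped = foo(Bar)                   # closure that captures the type object Bar
literal = (args...) -> Bar(args...)  # refers to the global constant Bar; captures nothing

# The captured variable becomes a field of the closure type.
println(fieldnames(typeof(wrapped)), " ", fieldtypes(typeof(wrapped)))

# A closure with a Type{...} field is not isbits, which is what validation rejects.
println(isbitstype(typeof(wrapped)))

# A closure without captures has no fields and is isbits.
println(isbitstype(typeof(literal)))
```

This matches the pattern in the original report, where the literal anonymous function broadcasts fine over GPU arrays while the result of foo(c) does not.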

We could make the validation more lenient again by having it actually consider whether the type-unstable argument is unused in the LLVM IR -- something we removed after Core.Compiler.isconstType in https://github.com/JuliaGPU/GPUCompiler.jl/pull/24 -- however that only fixes my MWE here, and not yours, because the Broadcasted argument (which is now also type-unstable) is used. So that approach doesn't cut it.

Maybe we should approach this differently. Julia converted this Broadcast to a { [1 x {}*], [2 x { { i8 addrspace(1)*, i64, [1 x i64], i64 }, [1 x i8], [1 x i64] }], [1 x [1 x i64]] } %1, but it seems hard to check whether that managed pointer in there is the only unused field...

All that said, I'm not sure it's worth the effort, because the generated broadcast kernel is broken regardless: due to the (inferred) type instability, the broadcast returns Any, which doesn't work on the GPU:

ERROR: ArgumentError: Broadcast operation resulting in Any is not GPU compatible
 [1] _copyto!
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:86 [inlined]
 [2] copyto!
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:44 [inlined]
 [3] copy
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [4] materialize
   @ ./broadcast.jl:903 [inlined]

maleadt commented 3 days ago

We could make the validation more lenient again by having it actually consider whether the type-unstable argument is unused in the LLVM IR -- something we removed after Core.Compiler.isconstType in JuliaGPU/GPUCompiler.jl#24

FWIW, that looks like:

diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb6..a4cff8f 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -88,8 +88,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool

     @timeit_debug to "Validation" begin
-        check_method(job)   # not optional
-        validate && check_invocation(job)
+        check_method(job)

@@ -99,6 +98,10 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool

     ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)

+    validate && @timeit_debug to "Validation" begin
+        check_invocation(job, ir_meta.entry)
+    end
     if output == :llvm
         if strip
             @timeit_debug to "strip debug info" strip_debuginfo!(ir)
diff --git a/src/validation.jl b/src/validation.jl
index e1a355b..9f1f869 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -66,7 +66,7 @@ function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10)
     return msg

-function check_invocation(@nospecialize(job::CompilerJob))
+function check_invocation(@nospecialize(job::CompilerJob), entry::LLVM.Function)
     sig = job.source.specTypes
     ft = sig.parameters[1]
     tt = Tuple{sig.parameters[2:end]...}
@@ -77,6 +77,9 @@ function check_invocation(@nospecialize(job::CompilerJob))
     real_arg_i = 0

     for (arg_i,dt) in enumerate(sig.parameters)
+        println(Core.stdout, arg_i)
+        println(Core.stdout, dt)
         isghosttype(dt) && continue
         Core.Compiler.isconstType(dt) && continue
         real_arg_i += 1
@@ -89,9 +92,13 @@ function check_invocation(@nospecialize(job::CompilerJob))

         if !isbitstype(dt)
-            throw(KernelError(job, "passing and using non-bitstype argument",
-                """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
-                    $(explain_nonisbits(dt))"""))
+            param = parameters(entry)[real_arg_i]
+            if !isempty(uses(param))
+                println(Core.stdout, string(entry))
+                throw(KernelError(job, "passing and using non-bitstype argument",
+                      """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
+                         $(explain_nonisbits(dt))"""))
+             end

BioTurboNick commented 3 days ago

Is there maybe a way to avoid the type instability? Or should closures of this kind be warned against, with packages such as Zygote changing accordingly?

maleadt commented 3 days ago

Zygote capturing all kinds of things in closures is definitely not great, but it's too late to fix that.

The type instability was a red herring, though. Even typing c wouldn't resolve this, because the type object itself is the problematic part:

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, var"#16#18"{Type{Bar{Float32}}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .f is of type var"#16#18"{Type{Bar{Float32}}} which is not isbits.
    .f is of type Type{Bar{Float32}} which is not isbits.

BioTurboNick commented 3 days ago

Okay, yeah. So I think on the Zygote side there's a way to work around the specific issue I encountered. Aside from trying to fix anything in CUDA, is there an opportunity to provide a more helpful error message in this case? I would be willing to work on that.

maleadt commented 3 days ago

Isn't it already relatively helpful, pointing to the exact argument and field?

In terms of actually fixing this, it would be possible to completely disable validation, as this code turns out to result in relatively compatible LLVM IR. But then we open the door to accidentally using pointers to boxed (host-allocated) objects on the GPU, which is what this check was designed to combat...

BioTurboNick commented 3 days ago

It's helpful if you already understand the internals, perhaps.

As a naive user: "Okay great, Type{Bar} is not isbits. What do I do with this information? How do I use the error to correct the code?"
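
One possible way to act on the error, sketched under assumptions: move the type out of the closure's captured field and into a type parameter of a field-less callable struct, so the object handed to the kernel carries no data. TypeCall is a hypothetical helper, not part of CUDA.jl or Zygote, and the fields of Bar are assumed. The sketch only checks isbits-ness on the CPU; whether the resulting broadcast compiles on the GPU has not been verified here.

```julia
# CPU-only sketch; field names of Bar are assumed for illustration.
struct Bar{T}
    x::T
    y::T
end

# Hypothetical helper: carries the target type as a type parameter, not as a field.
struct TypeCall{F} end
(::TypeCall{F})(args...) where {F} = F(args...)

foo(f) = (args...) -> f(args...)          # generic wrapper, as in the report
foo(::Type{F}) where {F} = TypeCall{F}()  # more specific method for type arguments

d = foo(Bar)
println(typeof(d))
println(isbitstype(typeof(d)))  # no fields, so the wrapper itself is isbits
println(d(1f0, 2f0))            # still forwards to the Bar constructor
```

The design choice is that dispatch on ::Type{F} keeps the call site unchanged (foo(c) still works) while avoiding a field of type Type{Bar} in the returned callable.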