FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

`Zygote` doesn't work properly with `Metal.jl` and half precision. #1482

Open benedict-96 opened 9 months ago

benedict-96 commented 9 months ago

I get a very long and complicated error when I do:

using Zygote: gradient
import Metal
using LinearAlgebra: norm

gradient(p -> norm(tanh.(p)), Metal.rand(Float16, 10))[1]

This, however, works with MtlArray{Float32} as well as with CuArray{Float16}.
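
For completeness, here are the combinations side by side (a sketch, with a small helper f added just for brevity; the CUDA line of course needs an NVIDIA GPU):

using Zygote: gradient
import Metal
using LinearAlgebra: norm

f(p) = norm(tanh.(p))

gradient(f, Metal.rand(Float32, 10))[1]    # works
gradient(f, Metal.rand(Float16, 10))[1]    # fails (this issue)
# on an NVIDIA machine, with `import CUDA`:
# gradient(f, CUDA.rand(Float16, 10))[1]   # works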

mcabbott commented 9 months ago

Please always post the error message.

Running this, I get the same error from just constructing the array:

julia> gradient(norm, Metal.rand(Float16, 10))[1]
ERROR: UndefVarError: `propagate_julia_addrsp!` not defined in `GPUCompiler`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
  [1] addOptimizationPasses!(pm::LLVM.ModulePassManager, opt_level::Int64)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:49

julia> Metal.rand(Float16, 10)
ERROR: UndefVarError: `propagate_julia_addrsp!` not defined in `GPUCompiler`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
  [1] addOptimizationPasses!(pm::LLVM.ModulePassManager, opt_level::Int64)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:49
  [2] addOptimizationPasses!(pm::LLVM.ModulePassManager, opt_level::Int64)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:24 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:183 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/LLVM/vIbji/src/base.jl:98 [inlined]
  [5] optimize!(job::GPUCompiler.CompilerJob, mod::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:181
  [6] macro expansion
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:338 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:337 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:311 [inlined]
 [11] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:89
 [12] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:129
 [13] codegen
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:110 [inlined]
 [14] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:106
 [15] compile
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:98 [inlined]
 [16] #51
    @ ~/.julia/packages/Metal/qeZqc/src/compiler/compilation.jl:57 [inlined]
 [17] JuliaContext(f::Metal.var"#51#52"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
 [18] compile(job::GPUCompiler.CompilerJob)
    @ Metal ~/.julia/packages/Metal/qeZqc/src/compiler/compilation.jl:56
 [19] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/execution.jl:125
 [20] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/execution.jl:103
 [21] macro expansion
    @ ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:162 [inlined]
 [22] macro expansion
    @ ./lock.jl:267 [inlined]
 [23] mtlfunction(f::GPUArrays.var"#98#99"{Float16}, tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:157
 [24] mtlfunction
    @ ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:155 [inlined]
 [25] macro expansion
    @ ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:77 [inlined]
 [26] #launch_heuristic#98
    @ ~/.julia/packages/Metal/qeZqc/src/gpuarrays.jl:14 [inlined]
 [27] launch_heuristic
    @ ~/.julia/packages/Metal/qeZqc/src/gpuarrays.jl:12 [inlined]
 [28] gpu_call(::GPUArrays.var"#98#99"{…}, ::Metal.MtlVector{…}, ::Metal.MtlVector{…}; target::Metal.MtlVector{…}, elements::Nothing, threads::Nothing, blocks::Nothing, name::Nothing)
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/device/execution.jl:61
 [29] gpu_call(::GPUArrays.var"#98#99"{…}, ::Metal.MtlVector{…}, ::Metal.MtlVector{…})
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/device/execution.jl:34
 [30] rand!(rng::GPUArrays.RNG, A::Metal.MtlVector{Float16, Metal.MTL.MTLResourceStorageModePrivate})
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/host/random.jl:87
 [31] rand!(A::Metal.MtlVector{Float16, Metal.MTL.MTLResourceStorageModePrivate})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/random.jl:6
 [32] rand(::Type, ::Int64; storage::Metal.MTL.MTLResourceOptions)
    @ Metal ~/.julia/packages/Metal/qeZqc/src/random.jl:14
 [33] top-level scope
    @ REPL[47]:1
 [34] top-level scope
    @ ~/.julia/packages/Metal/qeZqc/src/initialization.jl:51
Some type information was truncated. Use `show(err)` to see complete types.

julia> Metal.rand(Float32, 10)
ERROR: UndefVarError: `propagate_julia_addrsp!` not defined in `GPUCompiler`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
  [1] addOptimizationPasses!(pm::LLVM.ModulePassManager, opt_level::Int64)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:49

Here, with a constructor that runs, norm fails:

julia> Metal.MtlArray(rand(Float16, 10))
10-element Metal.MtlVector{Float16, Metal.MTL.MTLResourceStorageModePrivate}:
 0.01514
 0.2837
 0.66
 0.2153
 0.2637
 0.7236
 0.8022
 0.865
 0.3765
 0.03662

julia> norm(Metal.MtlArray(rand(Float16, 10)))
ERROR: UndefVarError: `propagate_julia_addrsp!` not defined in `GPUCompiler`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
  [1] addOptimizationPasses!(pm::LLVM.ModulePassManager, opt_level::Int64)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/optim.jl:49
...
benedict-96 commented 9 months ago

Thanks for the quick response, and sorry for not providing the error message! For me, norm(Metal.MtlArray(rand(Float16, 10))) does not fail (Mac mini 2023 with an M2 and Metal v0.5.1). Metal.rand(Float16, 10) and norm(tanh.(Metal.rand(Float16, 10))) don't fail either.

Here is the error message with Zygote:

julia> gradient(p -> norm(tanh.(p)), Metal.rand(Float16, 10))[1]
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{Float16, 1}, ::Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{MtlDeviceVector{Float16, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Metal.MtlRefValue{typeof(^)}, Base.Broadcast.Extruded{MtlDeviceVector{Float16, 1}, Tuple{Bool}, Tuple{Int64}}, Metal.MtlRefValue{Val{2}}}}}}}}}}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1198
  [2] ^
    @ ./math.jl:1200
  [3] literal_pow
    @ ./intfuncs.jl:325
  [4] _broadcast_getindex_evalf
    @ ./broadcast.jl:683
  [5] _broadcast_getindex
    @ ./broadcast.jl:656
  [6] _getindex
    @ ./broadcast.jl:680
  [7] _getindex
    @ ./broadcast.jl:679
  [8] _broadcast_getindex
    @ ./broadcast.jl:655
  [9] _getindex
    @ ./broadcast.jl:680
 [10] _broadcast_getindex
    @ ./broadcast.jl:655
 [11] _getindex
    @ ./broadcast.jl:680
 [12] _getindex
    @ ./broadcast.jl:679
 [13] _broadcast_getindex
    @ ./broadcast.jl:655
 [14] getindex
    @ ./broadcast.jl:610
 [15] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
mcabbott commented 9 months ago

Ok, I don't know what's wrong with my installation then.

But the error indicates that something is creating an intermediate Float64 somewhere. The gradient rule for tanh.(x) is here, and the rule for norm is, I think, here in ChainRules (CR). Can you isolate which operations cause this promotion?
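
For example, inspecting the scalar pieces on the CPU might show where a double sneaks in (a sketch; @code_llvm comes from InteractiveUtils, which the REPL loads by default):

julia> x = Float16(0.5); n = 2;

julia> @code_llvm x ^ n
# look for `double` in the output: ^(::Float16, ::Integer) appears to
# widen to Float32, and from there to Float64, inside Base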

benedict-96 commented 9 months ago

Oh, sorry; I should have checked this before. The problem seems to be p -> tanh.(p): gradient(p -> norm(p), Metal.rand(Float16, 10))[1] works, but

julia> pullback(p -> tanh.(p), Metal.rand(Float16, 10))[2](Metal.rand(Float16, 10))
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::Metal.mtlKernelContext, ::MtlDeviceVector{Float16, 1}, ::Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{MtlDeviceVector{Float16, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Metal.MtlRefValue{typeof(^)}, Base.Broadcast.Extruded{MtlDeviceVector{Float16, 1}, Tuple{Bool}, Tuple{Int64}}, Metal.MtlRefValue{Val{2}}}}}}}}}}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
  [1] ^
    @ ./math.jl:1198
  [2] ^
    @ ./math.jl:1200
  [3] literal_pow
    @ ./intfuncs.jl:325
  [4] _broadcast_getindex_evalf
    @ ./broadcast.jl:683
  [5] _broadcast_getindex
    @ ./broadcast.jl:656
  [6] _getindex
    @ ./broadcast.jl:680
  [7] _getindex
    @ ./broadcast.jl:679
  [8] _broadcast_getindex
    @ ./broadcast.jl:655
  [9] _getindex
    @ ./broadcast.jl:680
 [10] _broadcast_getindex
    @ ./broadcast.jl:655
 [11] _getindex
    @ ./broadcast.jl:680
 [12] _getindex
    @ ./broadcast.jl:679
 [13] _broadcast_getindex
    @ ./broadcast.jl:655
 [14] getindex
    @ ./broadcast.jl:610
 [15] broadcast_kernel
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:64
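
If I read the trace right, the double comes from Base rather than from Zygote: the last frames point at ^ in math.jl, and Base appears to evaluate ^(::Float16, ::Integer) by widening to Float32 and then computing through Float64, which Metal kernels cannot execute (CUDA.jl presumably works because it provides its own device-side ^). The failing kernel corresponds to the expression Δ .* conj.(1 .- y .^ 2), so the error should be reproducible, and avoidable, without Zygote at all (a sketch):

julia> y, Δ = Metal.rand(Float16, 10), Metal.rand(Float16, 10);

julia> Δ .* conj.(1 .- y .^ 2)   # expected: the same InvalidIRError, since
                                 # literal_pow falls back to ^(y, 2)

julia> Δ .* conj.(1 .- y .* y)   # expected to compile: only Float16 *, -, conj

If so, a fix would be for the tanh pullback to square with y .* y (or abs2) instead of y .^ 2, or for Metal.jl to provide a device-side ^(::Float16, ::Integer) that stays in half or single precision.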