Open ngphuoc opened 9 months ago
Taking the gradient of `sort` on a 2D `CuArray` with the `dims` keyword gives the error below. An MWE for this bug:
```
using CUDA, Zygote

x = CUDA.rand(3)
gradient(x -> sum(sort(x)), x)  # OK

x = CUDA.rand(3, 4)
gradient(x -> sum(sort(x, dims=1)), x)  # error

ERROR: `llvmcall` must be compiled to be called
Stacktrace:
  [1] macro expansion @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
  [3] getindex @ ./atomics.jl:358 [inlined]
  [4] getindex @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/abstractarray.jl:48 [inlined]
  [5] _pullback(ctx::Zygote.Context{false}, f::typeof(getindex), args::GPUArrays.RefCounted{CUDA.Mem.DeviceBuffer}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [6] getindex @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/abstractarray.jl:72 [inlined]
  [7] is_unified @ ~/.julia/packages/CUDA/htRwP/src/array.jl:150 [inlined]
  [8] adapt_storage @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:143 [inlined]
  [9] _pullback(::Zygote.Context{…}, ::typeof(Adapt.adapt_storage), ::CUDA.KernelAdaptor, ::CuArray{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [10] adapt_structure @ Adapt ~/.julia/packages/Adapt/2aJr7/src/Adapt.jl:57 [inlined]
 [11] adapt @ Adapt ~/.julia/packages/Adapt/2aJr7/src/Adapt.jl:40 [inlined]
 [12] cudaconvert @ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:196 [inlined]
 [13] #662 @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187 [inlined]
 [14] map @ Base ./tuple.jl:294 [inlined]
 [15] ∇map(cx::Zygote.Context{…}, f::typeof(cudaconvert), args::Tuple{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187
 [16] adjoint @ ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:213 [inlined]
 [17] _pullback @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [18] macro expansion @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:110 [inlined]
 [19] #quicksort!#9 @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:475 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::CUDA.QuickSortImpl.var"##quicksort!#9", ::typeof(isless), ::typeof(identity), ::Int64, ::Nothing, ::Int64, ::typeof(CUDA.QuickSortImpl.quicksort!), ::CuArray{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [21] quicksort! @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:462 [inlined]
 [22] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(CUDA.QuickSortImpl.quicksort!), ::CuArray{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [23] #sort!#1179 @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:947 [inlined]
 [24] sort! @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:941 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(sort!), ::CuArray{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [26] #sort#1180 @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:952 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::CUDA.var"##sort#1180", ::@Kwargs{…}, ::typeof(sort), ::CuArray{…}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [28] sort @ CUDA ~/.julia/packages/CUDA/htRwP/src/sorting.jl:951 [inlined]
 [29] #3 @ ./REPL[6]:1 [inlined]
 [30] _pullback(ctx::Zygote.Context{false}, f::var"#3#4", args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [32] pullback(f::Any, cx::ZygoteRules.AContext, args::Vararg{Any}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [33] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}) @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [34] top-level scope @ REPL[6]:1
 [35] top-level scope @ ~/.julia/packages/CUDA/htRwP/src/initialization.jl:206
Some type information was truncated. Use `show(err)` to see complete types.
```
`Project.toml`:

```
[052768ef] CUDA v5.2.0
[e88e6eb3] Zygote v0.6.69
```
Details on Julia:
```
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
  Threads: 143 on 96 virtual cores
Environment:
  JULIA_CONDAPKG_BACKEND = Current
  JULIA_DEPOT_PATH = ~/.julia
  JULIA_HOME = ~/julia-1.8.0
  JULIA_NUM_THREADS = auto
  LD_LIBRARY_PATH = /lib/x86_64-linux-gnu:/opt/local/cuda-10.1/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64::/usr/local/cuda-11.3/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64:/usr/local/cuda-11.3/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64
```
Details on CUDA:
```
CUDA runtime 12.3, artifact installation
CUDA driver 12.3
NVIDIA driver 450.248.2, originally for CUDA 11.0

CUDA libraries:
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 11.0.0+450.248.2

Julia packages:
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.0
- LLVM: 15.0.7

16 devices:
  0: Tesla V100-SXM3-32GB (sm_70, 17.722 GiB / 31.749 GiB available)
  ...
```
ChainRules.jl only has a rule for `sort` on vectors. Adding one there for `sort` on higher-dimensional arrays should do the trick. I'd recommend opening an issue (or better yet, a PR) there.
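To illustrate what a rule along a dimension might look like, here is a minimal sketch in plain Julia. It is not the ChainRules.jl implementation: `sort_with_pullback` is a hypothetical helper that returns the same `(value, pullback)` pair an `rrule` would, and it assumes `sortperm(x; dims)` (available from Julia 1.9) returns linear indices into `x`. It has only been reasoned through for CPU arrays; whether the same indexing works, and works efficiently, on a `CuArray` with a given CUDA.jl version is an assumption to verify.

```julia
# Hypothetical helper, not part of ChainRules.jl or Zygote.
# Returns the sorted array and a pullback, mirroring the shape of an `rrule`.
function sort_with_pullback(x::AbstractArray; dims::Integer)
    # Linear indices that put `x` in sorted order along `dims` (Julia 1.9+).
    p = sortperm(x; dims=dims)
    y = x[p]
    function sort_pullback(dy)
        # Sorting only permutes entries, so the pullback scatters each
        # incoming cotangent back to the position its value came from.
        dx = similar(x, eltype(dy))
        dx[p] = dy
        return dx
    end
    return y, sort_pullback
end

# Usage on a small CPU matrix, sorting each column.
x = [3.0 1.0; 2.0 4.0]
y, back = sort_with_pullback(x; dims=1)
dx = back([10.0 30.0; 20.0 40.0])
```

An actual rule would be written as a `ChainRulesCore.rrule` method for `sort`, and its pullback would additionally return `NoTangent()` for the function argument.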