FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Cannot take gradient of sort on 2D CuArray #1499

Open ngphuoc opened 9 months ago

ngphuoc commented 9 months ago

Taking the gradient of `sort` on a 2D CuArray with the `dims` keyword throws the error below. The 1D case works. An MWE:

```julia
using CUDA, Zygote

x = CUDA.rand(3)
gradient(x -> sum(sort(x)), x)  # OK: 1D CuArray

x = CUDA.rand(3, 4)
gradient(x -> sum(sort(x, dims=1)), x)  # error: 2D CuArray with dims
```
ERROR: `llvmcall` must be compiled to be called
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
  [3] getindex
    @ ./atomics.jl:358 [inlined]
  [4] getindex
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/abstractarray.jl:48 [inlined]
  [5] _pullback(ctx::Zygote.Context{false}, f::typeof(getindex), args::GPUArrays.RefCounted{CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [6] getindex
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/abstractarray.jl:72 [inlined]
  [7] is_unified
    @ ~/.julia/packages/CUDA/htRwP/src/array.jl:150 [inlined]
  [8] adapt_storage
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:143 [inlined]
  [9] _pullback(::Zygote.Context{…}, ::typeof(Adapt.adapt_storage), ::CUDA.KernelAdaptor, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [10] adapt_structure
    @ Adapt ~/.julia/packages/Adapt/2aJr7/src/Adapt.jl:57 [inlined]
 [11] adapt
    @ Adapt ~/.julia/packages/Adapt/2aJr7/src/Adapt.jl:40 [inlined]
 [12] cudaconvert
    @ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:196 [inlined]
 [13] #662
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187 [inlined]
 [14] map
    @ Base ./tuple.jl:294 [inlined]
 [15] ∇map(cx::Zygote.Context{…}, f::typeof(cudaconvert), args::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187
 [16] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:213 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [18] macro expansion
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:110 [inlined]
 [19] #quicksort!#9
    @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:475 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::CUDA.QuickSortImpl.var"##quicksort!#9", ::typeof(isless), ::typeof(identity), ::Int64, ::Nothing, ::Int64, ::typeof(CUDA.QuickSortImpl.quicksort!), ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [21] quicksort!
    @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:462 [inlined]
 [22] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(CUDA.QuickSortImpl.quicksort!), ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [23] #sort!#1179
    @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:947 [inlined]
 [24] sort!
    @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:941 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(sort!), ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [26] #sort#1180
    @ ~/.julia/packages/CUDA/htRwP/src/sorting.jl:952 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::CUDA.var"##sort#1180", ::@Kwargs{…}, ::typeof(sort), ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [28] sort
    @ CUDA ~/.julia/packages/CUDA/htRwP/src/sorting.jl:951 [inlined]
 [29] #3
    @ ./REPL[6]:1 [inlined]
 [30] _pullback(ctx::Zygote.Context{false}, f::var"#3#4", args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [32] pullback(f::Any, cx::ZygoteRules.AContext, args::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [33] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [34] top-level scope
    @ REPL[6]:1
 [35] top-level scope
    @ ~/.julia/packages/CUDA/htRwP/src/initialization.jl:206
Some type information was truncated. Use `show(err)` to see complete types.
Project.toml:

```
[052768ef] CUDA v5.2.0
[e88e6eb3] Zygote v0.6.69
```

Details on Julia:

julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
  Threads: 143 on 96 virtual cores
Environment:
  JULIA_CONDAPKG_BACKEND = Current
  JULIA_DEPOT_PATH = ~/.julia
  JULIA_HOME = ~/julia-1.8.0
  JULIA_NUM_THREADS = auto
  LD_LIBRARY_PATH = /lib/x86_64-linux-gnu:/opt/local/cuda-10.1/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64::/usr/local/cuda-11.3/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64:/usr/local/cuda-11.3/lib64:/AvaStore/opt/nvidia/hpc_sdk/Linux_x86_64/20.7/cuda/11.0/lib64

Details on CUDA:

CUDA runtime 12.3, artifact installation
CUDA driver 12.3
NVIDIA driver 450.248.2, originally for CUDA 11.0

CUDA libraries:
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 11.0.0+450.248.2

Julia packages:
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.0
- LLVM: 15.0.7

16 devices:
  0: Tesla V100-SXM3-32GB (sm_70, 17.722 GiB / 31.749 GiB available)
  ...

ToucheSir commented 9 months ago

ChainRules.jl only has a rule for `sort` on vectors. Adding one there for `sort` on higher-dimensional arrays should do the trick. I'd recommend opening an issue (or better yet, a PR) there.
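
To make the suggestion concrete, here is a minimal sketch of what such a rule could look like. It is not the ChainRules.jl implementation. The wrapper `sort_along` is a hypothetical name, used so the sketch does not add methods for `Base.sort` itself. It also assumes Julia 1.9 or later, where `sortperm` accepts `dims` on multidimensional arrays and returns linear indices. It has only been reasoned through for CPU arrays; whether `sortperm` with `dims` and the indexed assignment behave the same on a `CuArray` is an assumption that would need checking.

```julia
using ChainRulesCore

# Hypothetical wrapper (not part of any package) around sort with dims.
sort_along(x::AbstractArray; dims::Integer) = sort(x; dims=dims)

function ChainRulesCore.rrule(::typeof(sort_along), x::AbstractArray; dims::Integer)
    # Linear indices into x such that x[inds] == sort(x; dims=dims).
    # Assumption: Julia >= 1.9, where sortperm supports `dims` for N-d arrays.
    inds = sortperm(x; dims=dims)
    y = x[inds]

    function sort_along_pullback(dy)
        dy_arr = unthunk(dy)
        # inds is a permutation of all linear indices of x, so every entry
        # of dx is written exactly once and no zero-initialisation is needed.
        dx = similar(x, eltype(dy_arr))
        dx[inds] = dy_arr  # route each output cotangent back to its source slot
        return NoTangent(), dx
    end

    return y, sort_along_pullback
end
```

With a rule like this loaded, `Zygote.gradient(x -> sum(sort_along(x; dims=1)), x)` should use it instead of tracing into the sorting kernel, since Zygote consults ChainRules rules. For the `sum` example in the MWE the expected gradient is an array of ones, because sorting only permutes the entries.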