FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Optimisers.update fails with gradient of type `CUDA.CUSPARSE.CuSparseMatrixCSC` #141

Closed hsseung closed 1 year ago

hsseung commented 1 year ago

I'd like my gradient to be a sparse matrix, while my parameters are a dense matrix. This works on CPU, but yields an error on GPU.

using SparseArrays, CUDA, Optimisers
n = 5
k = 2
U = rand(k, n)
grads = sprand(k, n, 0.5)
U = cu(U)
grads = cu(grads)
state = Optimisers.setup(Optimisers.Descent(), U) 
Optimisers.update!(state, U, grads)

Julia 1.8.5, CUDA 12.0, GTX 1080 Ti

ERROR: a exception was thrown during kernel execution.
Stacktrace:
 [1] error_if_canonical_getindex at ./abstractarray.jl:1260
 [2] getindex at ./abstractarray.jl:1240
 [3] _getindex at ./abstractarray.jl:1291
 [4] getindex at ./abstractarray.jl:1241
 [5] _broadcast_getindex at ./broadcast.jl:623
 [6] _getindex at ./broadcast.jl:666
 [7] _broadcast_getindex at ./broadcast.jl:642
 [8] _getindex at ./broadcast.jl:667
 [9] _getindex at ./broadcast.jl:666
 [10] _broadcast_getindex at ./broadcast.jl:642
 [11] getindex at ./broadcast.jl:597
 [12] broadcast_kernel at /usr/people/sseung/.julia/packages/GPUArrays/TnEpb/src/host/broadcast.jl:59
ERROR: CUDA error: unspecified launch failure (code 719, ERROR_LAUNCH_FAILED)
Stacktrace:
 [1] throw_api_error(res::CUDA.cudaError_enum)
   @ CUDA ~/.julia/packages/CUDA/is36v/lib/cudadrv/libcuda.jl:27
 [2] macro expansion
   @ ~/.julia/packages/CUDA/is36v/lib/cudadrv/libcuda.jl:35 [inlined]
 [3] cuCtxSynchronize
   @ ~/.julia/packages/CUDA/is36v/lib/utils/call.jl:26 [inlined]
 [4] nonblocking_synchronize
   @ ~/.julia/packages/CUDA/is36v/lib/cudadrv/context.jl:329 [inlined]
 [5] device_synchronize()
   @ CUDA ~/.julia/packages/CUDA/is36v/lib/cudadrv/context.jl:319
 [6] top-level scope
   @ ~/.julia/packages/CUDA/is36v/src/initialization.jl:164

ToucheSir commented 1 year ago

The MWE is just `broadcast!(-, U, U, grads)` or `U .= grads`. My guess is that broadcasting from a sparse matrix into a dense one is not supported by CUDA.jl. @maleadt, does that sound right to you?

hsseung commented 1 year ago

Yes, `U .-= grads` throws an error, but `U -= grads` is fine.

After the error, any attempt to access `grads` fails with `ERROR: CUDA error: unspecified launch failure (code 719, ERROR_LAUNCH_FAILED)`, and this is unrecoverable.

maleadt commented 1 year ago

We do actually support sparse broadcast of CSC/CSR matrices involving dense inputs. I think the problem is the in-place version; I only implemented the out-of-place version: https://github.com/JuliaGPU/CUDA.jl/blob/4a29605cbd66c61fbf3ad0727681663a3489f47c/lib/cusparse/broadcast.jl#L462-L600
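
A minimal sketch of that distinction, reusing the reproducer from the top of this issue. It is untested here and assumes a CUDA-capable GPU and the package versions reported above. Only the out-of-place form, which was reported to work earlier in this thread, is run; the in-place form is left commented out because it triggers the launch failure described above.

```julia
using SparseArrays, CUDA

n, k = 5, 2
U = cu(rand(k, n))              # dense parameters on the GPU
grads = cu(sprand(k, n, 0.5))   # sparse gradient on the GPU (CuSparseMatrixCSC)

# Out-of-place: allocates a new array and rebinds the name U.
# Reported to work earlier in this thread.
U = U - grads

# In-place: writes into the existing U via broadcast!.
# This is the form that fails, per the MWE above, so it is not run here.
# U .-= grads
```

Note that rebinding `U` does not mutate the original array, so any other reference to the old array will not see the update.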

hsseung commented 1 year ago

Thanks for the info. Guess I'll wait for CUDA.jl to support this.

CarloLucibello commented 1 year ago

It would be good to open an issue in CUDA.jl to track this.