JuliaGPU / CUDA.jl

CUDA programming in Julia.
https://juliagpu.org/cuda/

Failure of Eigenvalue Decomposition for Large Matrices. #2413

Open hkimman opened 2 weeks ago

hkimman commented 2 weeks ago

Describe the bug

Before updating, I was able to decompose a large (100,000 by 100,000) dense matrix with CUDA.CUSOLVER.syevd! using CUDA.jl 5.2.0. After updating to CUDA.jl 5.4.2, the same eigenvalue decomposition (CUDA.CUSOLVER.syevd!) fails even for a 15,000 by 15,000 dense matrix, with the following error message:

ERROR: InexactError: trunc(Int32, 2707806276)
Stacktrace:
  [1] throw_inexacterror(f::Symbol, ::Type{Int32}, val::Int64)
    @ Core .\boot.jl:634
  [2] checked_trunc_sint
    @ .\boot.jl:656 [inlined]
  [3] toInt32
    @ .\boot.jl:693 [inlined]
  [4] Int32
    @ .\boot.jl:783 [inlined]
  [5] convert
    @ .\number.jl:7 [inlined]
  [6] cconvert
    @ .\essentials.jl:543 [inlined]
  [7] macro expansion
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\utils\call.jl:226 [inlined]
  [8] macro expansion
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\cusolver\libcusolver.jl:3043 [inlined]
  [9] #506
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\utils\call.jl:35 [inlined]
 [10] retry_reclaim
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\src\memory.jl:434 [inlined]
 [11] check
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\cusolver\libcusolver.jl:24 [inlined]
 [12] cusolverDnSsyevd
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\utils\call.jl:34 [inlined]
 [13] (::CUDA.CUSOLVER.var"#1362#1364"{…})(buffer::CuArray{…})
    @ CUDA.CUSOLVER C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\cusolver\dense.jl:640
 [14] with_workspaces(f::CUDA.CUSOLVER.var"#1362#1364"{…}, cache_gpu::Nothing, cache_cpu::Nothing, size_gpu::CUDA.CUSOLVER.var"#bufferSize#1363"{…}, size_cpu::Int64)
    @ CUDA.APIUtils C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\utils\call.jl:131
 [15] with_workspace
    @ C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\utils\call.jl:67 [inlined]
 [16] syevd!(jobz::Char, uplo::Char, A::CuArray{Float32, 2, CUDA.DeviceMemory})
    @ CUDA.CUSOLVER C:\Users\kimma\.julia\packages\CUDA\75aiI\lib\cusolver\dense.jl:639
 [17] top-level scope
    @ c:\Users\kimma\SynologyDrive\Code\Julia\cuda_bug.jl:5
Some type information was truncated. Use `show(err)` to see complete types.

To reproduce

The Minimal Working Example (MWE) for this bug:

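```julia
using LinearAlgebra
using CUDA

# Dense symmetric 15,000 by 15,000 matrix built on the host (Float64).
A = Matrix(Symmetric(rand(15_000, 15_000)))
# cu() copies to the GPU as Float32; 'V' = also compute eigenvectors, 'U' = use the upper triangle.
result = CUDA.CUSOLVER.syevd!('V', 'U', cu(A))
```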
Manifest.toml

```
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "31a65ef0d76c2eb9948f2f097d929af9853d3bbc"

[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"]
git-tree-sha1 = "6e945e876652f2003e6ca74e19a3c45017d3e9f6"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "5.4.2"

    [deps.CUDA.extensions]
    ChainRulesCoreExt = "ChainRulesCore"
    EnzymeCoreExt = "EnzymeCore"
    SpecialFunctionsExt = "SpecialFunctions"

    [deps.CUDA.weakdeps]
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
    EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
    SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "c48f9da18efd43b6b7adb7ee1f93fe5f2926c339"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
version = "0.9.0+0"

[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "5db9da5fdeaa708c22ba86b82c49528f402497f2"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.3.3"

[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "bcba305388e16aa5c879e896726db9e71b4942c6"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.14.0+1"

[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
git-tree-sha1 = "38cb19b8a3e600e509dc36a6396ac74266d108c1"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "10.1.1"

[[deps.GPUArraysCore]]
deps = ["Adapt"]
git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950"
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.6"

[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "518ebd058c9895de468a8c255797b0c53fdb44dd"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.26.5"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "389aea28d882a40b5e1747069af71bdbd47a1cae"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "7.2.1"
weakdeps = ["BFloat16s"]

    [deps.LLVM.extensions]
    BFloat16sExt = "BFloat16s"
```

Version info

Details on Julia:

Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 32 × AMD Ryzen 9 7945HX with Radeon Graphics
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =

Details on CUDA:

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.99.0

CUDA libraries:
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+555.99

Julia packages:
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

1 device:
  0: NVIDIA GeForce RTX 4090 Laptop GPU (sm_89, 12.148 GiB / 15.992 GiB available)
maleadt commented 2 weeks ago

I don't immediately see how this would have worked before, as the legacy cuSOLVER APIs have always taken 32-bit integer sizes. I guess we should be using the new generic APIs here, i.e. cusolverDnXsyevd instead of cusolverDnSsyevd, but that's a larger change. The 64-bit APIs were introduced in CUDA 11.1 (https://docs.nvidia.com/cuda/archive/11.1.1/cuda-toolkit-release-notes/index.html#cusolver-new-features), and the old ones are now officially deprecated (https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/#cusolver-release-12-4-update-1).
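A minimal sketch of the overflow in plain Julia, with no GPU required. Of the arguments to cusolverDnSsyevd, only n, lda and lwork are C `int` (Int32) sizes; since n = lda = 15,000 in the MWE, the value 2707806276 in the error is presumably the workspace length lwork (an inference, not something the trace states). The snippet only shows that this value cannot be passed as an Int32; it does not model how CUDA.jl computed it.

```julia
# Value from the InexactError in the report (an Int64 on 64-bit Julia).
reported_val = 2_707_806_276

# Legacy cuSOLVER routines such as cusolverDnSsyevd take C `int` sizes,
# which Julia passes as Int32 (Cint).
@show typemax(Int32)                  # typemax(Int32) = 2147483647
@show reported_val > typemax(Int32)   # reported_val > typemax(Int32) = true

# Converting anyway reproduces the error seen in the stack trace.
try
    Int32(reported_val)
catch err
    println(sprint(showerror, err))   # InexactError: trunc(Int32, 2707806276)
end
```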

cc @amontoison

amontoison commented 2 weeks ago

@maleadt I planned to wait for CUDA 13.0 to change the dispatch in CUDA.jl. With CUSPARSE, the generic routines were buggy until the legacy API was removed (CUDA v12.0), and cuSOLVER may be similar. What we can do for now is check the dimension of the matrix and dispatch to Xsyevd (and the other generic routines) only when the matrix is too large to be supported by the 32-bit API.

```julia
if CUSOLVER.version() >= v"11.1" && (n >= 2^31 - 1)
    Xsyevd(...)
else
    syevd(...)
end
```

For CUDA 13.0, we can simply replace the condition with:

```julia
CUSOLVER.version() >= v"13.0" || (CUSOLVER.version() >= v"11.1" && (n >= 2^31 - 1))
```
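The size-based dispatch proposed above could also be sketched as a predicate over every Int32 size argument of the legacy API, not only n. In this sketch `needs_64bit_api`, `lapack_syevd_lwork` and `first_overflowing_n` are hypothetical helpers, not CUDA.jl functions, and the workspace formula is LAPACK's documented minimum for ?syevd with jobz = 'V' (1 + 6n + 2n^2), used as an assumed stand-in because the size returned by cuSOLVER's bufferSize query may differ.

```julia
# LAPACK's documented minimum workspace for ?syevd with jobz = 'V' and N > 1:
# LWORK >= 1 + 6N + 2N^2. Assumed stand-in only; the size returned by
# cusolverDn<t>syevd_bufferSize may differ.
lapack_syevd_lwork(n::Integer) = 1 + 6 * Int64(n) + 2 * Int64(n)^2

# Hypothetical predicate (not a CUDA.jl function): true when any size argument
# of the legacy 32-bit API (n, lda, lwork) would overflow a C int / Int32.
needs_64bit_api(sizes::Integer...) = any(>(typemax(Int32)), sizes)

# Smallest n whose stand-in workspace length no longer fits in an Int32.
function first_overflowing_n()
    n = 1
    while !needs_64bit_api(n, n, lapack_syevd_lwork(n))
        n += 1
    end
    return n
end

println(first_overflowing_n())                           # 32767
println(needs_64bit_api(15_000, 15_000, 2_707_806_276))  # true
```

Under this stand-in formula, the workspace length exceeds typemax(Int32) from n = 32,767 onward, long before n itself would.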
maleadt commented 4 days ago

Related MWE from https://github.com/JuliaGPU/CUDA.jl/issues/2427:

```julia
using LinearAlgebra, CUDA

# Random 10,000 by 10,000 Float64 matrix on the GPU, symmetrized as a + a'.
a = CUDA.rand(Float64, 10000, 10000)
b = a + a'
eigen(b)
```