JuliaGPU / CUDA.jl

CUDA programming in Julia.
https://juliagpu.org/cuda/

Make generic_trimatmul more specific #2359

Closed: tgymnich closed this issue 4 months ago

tgymnich commented 5 months ago

I am currently implementing a generic version of generic_trimatmul! in https://github.com/JuliaGPU/GPUArrays.jl/pull/531.

The GPUArrays test suite keeps failing for CUDA, AMDGPU and oneAPI though: https://buildkite.com/julialang/gpuarrays-dot-jl/builds/861#018f3433-15d5-4251-8a39-9a37d194eb22

Some tests did not pass: 347 passed, 0 failed, 24 errored, 0 broken.
gpuarrays/linalg: Error During Test at /var/lib/buildkite-agent/builds/gpuci-4/julialang/gpuarrays-dot-jl/test/testsuite/linalg.jl:138
  Got exception outside of a @test
  MethodError: generic_trimatmul!(::CuArray{Float32, 1, CUDA.DeviceMemory}, ::Char, ::Char, ::typeof(identity), ::CuArray{Float32, 2, CUDA.DeviceMemory}, ::CuArray{Float32, 1, CUDA.DeviceMemory}) is ambiguous.
  Candidates:
    generic_trimatmul!(C::GPUArraysCore.AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::GPUArraysCore.AbstractGPUMatrix, B::GPUArraysCore.AbstractGPUVecOrMat)
      @ GPUArrays /var/lib/buildkite-agent/builds/gpuci-4/julialang/gpuarrays-dot-jl/src/host/linalg.jl:506
    generic_trimatmul!(c::StridedCuVector{T}, uploc, isunitc, tfun::Function, A::StridedCuMatrix{T}, b::AbstractVector{T}) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
      @ CUDA.CUBLAS ~/.cache/julia-buildkite-plugin/depots/c9f52312-b528-44e4-9501-6d408762012b/dev/CUDA/lib/cublas/linalg.jl:224
  Possible fix, define
    generic_trimatmul!(::CuArray{T, 1}, ::Any, ::Any, ::Function, ::CuArray{T, 2}, ::GPUArraysCore.AbstractGPUVector{T}) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
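
The ambiguity can be reproduced without a GPU. The following is a minimal sketch, not the actual GPUArrays.jl or CUDA.jl code: `AbstractFakeGPUArray`, `FakeCuArray`, `trimul_loose!` and `trimul_strict!` are hypothetical stand-ins, and the `uploc`, `isunitc` and `tfun` arguments are dropped because they are identical in both candidates. It assumes the conflict comes from the last argument: the vendor method accepts any `AbstractVector{T}`, while the generic method only accepts GPU arrays, so neither signature is more specific than the other.

```julia
# Hypothetical stand-ins for the abstract GPU array type and a concrete backend array.
abstract type AbstractFakeGPUArray{T,N} <: AbstractArray{T,N} end

struct FakeCuArray{T,N} <: AbstractFakeGPUArray{T,N}
    data::Array{T,N}
end
Base.size(a::FakeCuArray) = size(a.data)
Base.getindex(a::FakeCuArray, i::Int...) = a.data[i...]

const FakeGPUVector{T} = AbstractFakeGPUArray{T,1}
const FakeGPUMatrix{T} = AbstractFakeGPUArray{T,2}
const FakeGPUVecOrMat{T} = Union{FakeGPUVector{T},FakeGPUMatrix{T}}

const FakeBlasFloat = Union{Float32,Float64}

# Loose variant: the vendor method takes any AbstractVector{T} as its last argument.
trimul_loose!(C::FakeGPUVecOrMat, A::FakeGPUMatrix, B::FakeGPUVecOrMat) = :generic
trimul_loose!(c::FakeCuArray{T,1}, A::FakeCuArray{T,2},
              b::AbstractVector{T}) where {T<:FakeBlasFloat} = :vendor

# Strict variant: the vendor method only takes the backend's own vector type,
# so its signature is a subtype of the generic one.
trimul_strict!(C::FakeGPUVecOrMat, A::FakeGPUMatrix, B::FakeGPUVecOrMat) = :generic
trimul_strict!(c::FakeCuArray{T,1}, A::FakeCuArray{T,2},
               b::FakeCuArray{T,1}) where {T<:FakeBlasFloat} = :vendor

c = FakeCuArray(zeros(Float32, 2))
A = FakeCuArray(zeros(Float32, 2, 2))
b = FakeCuArray(zeros(Float32, 2))

# The loose pair is ambiguous for an all-backend call.
loose_result = try
    trimul_loose!(c, A, b)
catch err
    err
end

# The strict pair dispatches to the vendor method for matching element types
# and falls back to the generic method otherwise.
strict_result = trimul_strict!(c, A, b)
fallback_result = trimul_strict!(c, A, FakeCuArray(zeros(Float64, 2)))
```

With the narrowed signature, calls with mixed element types or unsupported element types still reach the generic method, so nothing that previously dispatched to the generic fallback is lost.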

So I would like to make the signature of generic_trimatmul! in CUDA.jl a bit more specific, so that it is in line with trsv!, to which generic_trimatmul! dispatches for vectors:

function trsv!(uplo::Char, trans::Char, diag::Char, A::StridedCuMatrix{$elty}, x::StridedCuVector{$elty})

maleadt commented 5 months ago

cc @dkarrasch

dkarrasch commented 4 months ago

Seems fine to me, if that's the requirement of tr[s/m]v anyway.