JuliaGPU / GPUArrays.jl

Reusable array functionality for Julia's various GPU backends.
MIT License

Reduce number of `mul!` methods #472

Closed dkarrasch closed 1 year ago

dkarrasch commented 1 year ago

In the context of https://github.com/JuliaGPU/CUDA.jl/pull/1904 I realized that there is yet another layer between LinearAlgebra and CUDA, namely GPUArrays.jl. I believe that if we hook into the method hierarchy one level lower, many things will become much simpler. I don't know whether this should be annotated with `@inline` so that constant propagation could potentially eliminate the character fiddling. Comments welcome!
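For readers unfamiliar with the "character fiddling" mentioned above, here is a minimal sketch of the general idea, assuming the high-level entry point strips one wrapper layer, encodes it as a character tag, and forwards to a single low-level method. The names `tag_and_parent`, `lowlevel_mul!` and `sketch_mul!` are hypothetical and chosen for illustration only; they are not the actual LinearAlgebra or GPUArrays.jl internals.

```julia
using LinearAlgebra

# Hypothetical helper: map a (possibly once-wrapped) matrix to a character tag
# and its unwrapped parent. 'N' = no wrapper, 'T' = transpose, 'C' = adjoint.
tag_and_parent(A::AbstractMatrix) = ('N', A)
tag_and_parent(A::Transpose{<:Any,<:AbstractMatrix}) = ('T', parent(A))
tag_and_parent(A::Adjoint{<:Any,<:AbstractMatrix}) = ('C', parent(A))

# Hypothetical low-level method: the only one a backend would need to overload.
# It receives plain (unwrapped) arrays plus the two tags.
function lowlevel_mul!(C, tA::Char, tB::Char, A, B)
    opA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
    opB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint
    C .= opA(A) * opB(B)  # stand-in for a backend kernel such as a GEMM call
    return C
end

# Hypothetical high-level entry point: unwrap once, then forward.
function sketch_mul!(C, A, B)
    tA, pA = tag_and_parent(A)
    tB, pB = tag_and_parent(B)
    return lowlevel_mul!(C, tA, tB, pA, pB)
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
C = sketch_mul!(zeros(2, 2), A', B)  # dispatches with tags ('C', 'N')
```

With this layout, a backend overloads one method instead of one per wrapper combination. Whether the tag comparisons are resolved at compile time depends on constant propagation, which is what the `@inline` question is about.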

maleadt commented 1 year ago

Seeing the Metal.jl CI failure, I guess this is a breaking change?

dkarrasch commented 1 year ago

Sorry for seemingly abandoning this PR. I was working on https://github.com/JuliaLang/julia/pull/49806, and now I'm working on including HermOrSym wrappers into that mechanism. Once that is done, I'll continue with SparseArrays.jl, so that these changes are included in Julia v1.10. Afterwards, I'll return to this one.

Regarding your question, this PR by itself is breaking, that is correct. But we can preempt breakage by introducing a few methods in Metal.jl first. Eventually, starting with Julia v1.10, there will be only one method overload that will handle multiplication by once-wrapped GPUArrays, where the wrappers include Adjoint, Transpose, Hermitian and Symmetric.

maleadt commented 1 year ago

> Sorry for seemingly abandoning this PR. I was working on JuliaLang/julia#49806, and now I'm working on including HermOrSym wrappers into that mechanism. Once that is done, I'll continue with SparseArrays.jl, so that these changes are included in Julia v1.10. Afterwards, I'll return to this one.

Of course, no problem; thanks for doing this!

maleadt commented 1 year ago

Ah, a new issue with Metal.jl; a bad convert method is being called:

  MethodError: convert(::Type{Union{}}, ::MtlMatrix{ComplexF32}) is ambiguous. Candidates:
    convert(T::Type{<:SparseArrays.AbstractSparseMatrixCSC}, m::AbstractMatrix) in SparseArrays at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-macmini-aarch64-2.0/julia_installs/bin/mac/aarch64/1.8/julia-1.8-latest-macaarch64/share/julia/stdlib/v1.8/SparseArrays/src/sparsematrix.jl:745
    convert(T::Type{<:LinearAlgebra.Bidiagonal}, m::AbstractMatrix) in LinearAlgebra at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-macmini-aarch64-2.0/julia_installs/bin/mac/aarch64/1.8/julia-1.8-latest-macaarch64/share/julia/stdlib/v1.8/LinearAlgebra/src/bidiag.jl:203
    convert(::Type{Union{}}, a::AbstractArray) in Base at array.jl:618
    convert(::Type{T}, a::AbstractArray) where T<:GPUArraysCore.AbstractGPUArray in GPUArrays at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-macmini-aarch64-2.0/build/default-macmini-aarch64-2-0/julialang/gpuarrays-dot-jl/src/host/construction.jl:4
    convert(T::Type{<:BitArray}, a::AbstractArray) in Base at bitarray.jl:580
    convert(::Type{T}, a::AbstractArray) where T<:Array in Base at array.jl:617
    convert(::Type{SA}, a::AbstractArray) where SA<:StaticArraysCore.StaticArray in StaticArrays at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-macmini-aarch64-2.0/depots/c9f52312-b528-44e4-9501-6d408762012b/packages/StaticArrays/J9itA/src/convert.jl:194
    convert(::Type{Union{}}, x) in Base at essentials.jl:213
    convert(::Type{T}, arg) where T<:VecElement in Base at baseext.jl:19
  Possible fix, define
    convert(::Type{Union{}}, ::AbstractMatrix)
  Stacktrace:
   [1] to_power_type(x::MtlMatrix{ComplexF32})
     @ Base ./intfuncs.jl:250

maleadt commented 1 year ago

/AppleInternal/Library/BuildRoots/9941690d-bcf7-11ed-a645-863efbbaf80d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSMatrix/LinearAlgebra/ARM64/MPSMatrixMultiplication.mm:3274: failed assertion `Number of requested rows in left input matrix exceeds left input matrix size.'

Interesting; I guess Metal.jl should protect against that to avoid an abort.
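As a rough illustration of what such protection could look like, the sketch below validates the sizes on the Julia side and throws a `DimensionMismatch` before anything reaches the native library, so that a bad call raises a catchable exception instead of aborting the process. The function name `checked_matmul_dims` is hypothetical and is not part of Metal.jl; it assumes any wrappers have already been resolved to the effective matrix sizes.

```julia
# Hypothetical guard, for illustration only: check that C = A * B is
# dimensionally consistent before handing the arrays to a native kernel.
function checked_matmul_dims(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
    m, k = size(A)
    k2, n = size(B)
    k == k2 || throw(DimensionMismatch(
        "A has size $(size(A)), B has size $(size(B))"))
    size(C) == (m, n) || throw(DimensionMismatch(
        "C has size $(size(C)), expected $((m, n))"))
    return nothing
end

# Consistent sizes pass silently.
checked_matmul_dims(zeros(2, 3), zeros(2, 4), zeros(4, 3))

# Inconsistent sizes raise an exception that the caller can handle.
caught = try
    checked_matmul_dims(zeros(2, 2), zeros(2, 3), zeros(2, 2))
    false
catch err
    err isa DimensionMismatch
end
```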

maleadt commented 1 year ago

Finally, all green. Let's tag things to get this out there!