jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License
286 stars, 34 forks

Write direct dispatches for `axpy!` & `axpby!` #225

Closed: avik-pal closed this 9 months ago

avik-pal commented 9 months ago

MWE:

julia> y = ComponentArray(a = cu(zeros(2)), b = cu(zeros(2)))
ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2, b = 3:4)}}}(a = Float32[0.0, 0.0], b = Float32[0.0, 0.0])

julia> x = ComponentArray(a = cu(rand(2)), b = cu(rand(2)))
ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2, b = 3:4)}}}(a = Float32[0.024100939, 0.77158135], b = Float32[0.40818077, 0.44528246])

julia> axpy!(-1, y, x)  # After the patch
ComponentVector{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2, b = 3:4)}}}, Tuple{Axis{(a = 1:2, b = 3:4)}}}(a = Float32[0.024100939, 0.77158135], b = Float32[0.40818077, 0.44528246])

julia> axpy!(-1, y, x)  # Before the patch for CUDA
ERROR: MethodError: no method matching axpy!(::Int64, ::Float32, ::CuPtr{Float32}, ::Int64, ::CuPtr{Float32}, ::Int64)

Closest candidates are:
  axpy!(::Integer, ::Float32, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:509
  axpy!(::Integer, ::ComplexF32, ::Union{Ptr{ComplexF32}, AbstractArray{ComplexF32}}, ::Integer, ::Union{Ptr{ComplexF32}, AbstractArray{ComplexF32}}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:509
  axpy!(::Integer, ::ComplexF64, ::Union{Ptr{ComplexF64}, AbstractArray{ComplexF64}}, ::Integer, ::Union{Ptr{ComplexF64}, AbstractArray{ComplexF64}}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:509
  ...

Stacktrace:
 [1] axpy!(alpha::Int64, x::ComponentVector{…}, y::ComponentVector{…})
   @ LinearAlgebra.BLAS ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:522
 [2] axpy!(α::Int64, x::ComponentVector{Float32, CuArray{…}, Tuple{…}}, y::ComponentVector{Float32, CuArray{…}, Tuple{…}})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1510
 [3] top-level scope
   @ REPL[46]:1
Some type information was truncated. Use `show(err)` to see complete types.

I will add tests, but unfortunately I couldn't reproduce the failure without CUDA: with JLArrays the calls still reach the correct BLAS routines.
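For illustration, here is the idea behind a "direct dispatch" as a sketch. Judging from the stack trace, the generic `LinearAlgebra.axpy!` forwards the ComponentVectors to `BLAS.axpy!`, which then calls the low-level six-argument `axpy!` with `CuPtr` device pointers. The closest candidates listed only accept CPU `Ptr`s or `AbstractArray`s. A direct method instead hands the backing arrays to `axpy!`/`axpby!`, so the backing array type chooses its own implementation. This is a minimal sketch under stated assumptions, not the code merged in this PR. `Wrapped` and `backing` are hypothetical names, with `backing` playing roughly the role of ComponentArrays' `getdata`. `Wrapped` is deliberately not an `AbstractArray`, so it cannot collide with LinearAlgebra's existing methods.

```julia
using LinearAlgebra

# Hypothetical stand-in for a ComponentVector: a thin wrapper around a flat
# backing array. Deliberately NOT an AbstractArray, so these methods cannot
# overlap with LinearAlgebra's own axpy!/axpby! methods.
struct Wrapped{A<:AbstractVector}
    data::A
end

# Hypothetical accessor, similar in spirit to ComponentArrays' `getdata`.
backing(w::Wrapped) = w.data

# Direct dispatch: perform the update on the backing arrays, so whatever
# method their type provides is used (BLAS for Vector{Float64}, the generic
# loop for other element types, a GPU method for a GPU array type).
function LinearAlgebra.axpy!(α::Number, x::Wrapped, y::Wrapped)
    axpy!(α, backing(x), backing(y))      # y.data .= α .* x.data .+ y.data
    return y                              # return the wrapper, as axpy! returns y
end

function LinearAlgebra.axpby!(α::Number, x::Wrapped, β::Number, y::Wrapped)
    axpby!(α, backing(x), β, backing(y))  # y.data .= α .* x.data .+ β .* y.data
    return y
end

x = Wrapped([1.0, 2.0])
y = Wrapped([10.0, 20.0])
axpy!(2, x, y)          # y.data is now [12.0, 24.0]
axpby!(1, x, 0.5, y)    # y.data is now [7.0, 14.0]
```

With `Vector{Float64}` backing arrays the forwarded calls reach BLAS. With a `Rational` element type they reach LinearAlgebra's generic loop. Either way, the wrapper never decides how the update is computed; it only unwraps.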

codecov-commenter commented 9 months ago

Codecov Report

Merging #225 (2730051) into main (df9bd66) will decrease coverage by 0.61%. The diff coverage is 0.00%.


@@            Coverage Diff             @@
##             main     #225      +/-   ##
==========================================
- Coverage   73.24%   72.64%   -0.61%     
==========================================
  Files          23       23              
  Lines         725      731       +6     
==========================================
  Hits          531      531              
- Misses        194      200       +6     
Files                    Coverage Δ
src/linear_algebra.jl    78.12% <0.00%> (-18.03%) ↓


avik-pal commented 9 months ago

@jonniedie do you want any additional changes to this?

jonniedie commented 9 months ago

Nope, looks good. Thanks!

vpuri3 commented 9 months ago

It would be good to also add `axpby!` methods for the mixed argument signatures (::Number, ::ComponentArray, ::Number, ::AbstractArray) and (::Number, ::AbstractArray, ::Number, ::ComponentArray).
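The mixed signatures suggested here could look roughly like the sketch below. It uses a hypothetical `Wrapped` type as a stand-in for a ComponentVector. This is an assumption for illustration, not ComponentArrays' API. Each method unwraps only the wrapped argument and returns the array that was updated, as `axpby!` does. One caveat applies to the real type. Because `ComponentArray` is itself an `AbstractArray`, both mixed methods would match a call with two ComponentArrays. A method taking two ComponentArrays is therefore needed to keep that call unambiguous. `Wrapped` sidesteps this by not being an `AbstractArray`, but the sketch still defines the both-wrapped method for completeness.

```julia
using LinearAlgebra

# Hypothetical stand-in for a ComponentVector; not an AbstractArray.
struct Wrapped{A<:AbstractVector}
    data::A
end

# Both arguments wrapped: forward both backing arrays, return the wrapper.
function LinearAlgebra.axpby!(α::Number, x::Wrapped, β::Number, y::Wrapped)
    axpby!(α, x.data, β, y.data)
    return y
end

# Mixed case 1: wrapped x, plain array y. Unwrap x only; axpby! returns y.
LinearAlgebra.axpby!(α::Number, x::Wrapped, β::Number, y::AbstractArray) =
    axpby!(α, x.data, β, y)

# Mixed case 2: plain array x, wrapped y. Update y's backing array in place,
# then return the wrapper itself.
function LinearAlgebra.axpby!(α::Number, x::AbstractArray, β::Number, y::Wrapped)
    axpby!(α, x, β, y.data)
    return y
end

w = Wrapped([1.0, 2.0])
v = [10.0, 20.0]
axpby!(2, w, 1, v)      # v is now [12.0, 24.0]
axpby!(1, v, 0.5, w)    # w.data is now [12.5, 25.0]
```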