JuliaGPU / CUDA.jl

CUDA programming in Julia.
https://juliagpu.org/cuda/

Add support for arbitrary group sizes in `gemm_grouped_batched!` #2334

Closed lpawela closed 2 months ago

lpawela commented 5 months ago

Currently `group_size` is hardcoded to ones. This PR adds support for arbitrary group sizes. It also changes the input type from `A::Vector{<:StridedCuMatrix{T}}` to `A::Vector{<:Vector{<:StridedCuMatrix{T}}}`.
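
For context, a minimal sketch of the two input layouts under discussion (illustrative only; the matrix sizes and group counts are made up, and the exact argument list of `gemm_grouped_batched!` is defined by the PR itself):

```julia
using CUDA

# Old layout: a flat vector of matrices; with group_size hardcoded to ones,
# each matrix was effectively its own group of one.
A_flat = [CUDA.rand(Float32, 4, 4) for _ in 1:3]

# New layout: a vector of groups; matrices within a group share dimensions,
# but different groups may use different dimensions.
A_grouped = [
    [CUDA.rand(Float32, 4, 4) for _ in 1:3],  # group 1: three 4×4 matrices
    [CUDA.rand(Float32, 8, 2) for _ in 1:2],  # group 2: two 8×2 matrices
]
```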

maleadt commented 5 months ago

This is a breaking change, so I'll let the original author review.

amontoison commented 5 months ago

Is it possible to keep the initial function with `A::Vector{<:StridedCuMatrix{T}}`? That way you avoid a breaking change.

The function taking `A::Vector{<:Vector{<:StridedCuMatrix{T}}}` as input is relevant, but users must then check which blocks have the same size. It should be an extension of the previous function.
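
A hypothetical sketch of that suggestion (not the PR's actual code; the grouped signature with per-group transposes and scalars is assumed here for illustration): keep the flat-vector method as a thin wrapper that treats its input as a single group and forwards to the grouped method.

```julia
using CUDA: StridedCuMatrix

# Hypothetical sketch: the old flat-vector method forwards to the new grouped
# method by wrapping every argument in a one-element group. Argument order
# mirrors the gemm_batched!-style wrappers, but the real signatures are
# defined by the PR.
function gemm_grouped_batched!(transA::Char, transB::Char, alpha::Number,
                               A::Vector{<:StridedCuMatrix{T}},
                               B::Vector{<:StridedCuMatrix{T}},
                               beta::Number,
                               C::Vector{<:StridedCuMatrix{T}}) where {T}
    return gemm_grouped_batched!([transA], [transB], [alpha], [A], [B], [beta], [C])
end
```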

lpawela commented 5 months ago

I can always restore the older version. But doesn't that just duplicate the behavior of `gemm_batched!` with some extra steps? In this case the switch would be relatively simple - just change some function names.

amontoison commented 5 months ago

> I can always restore the older version. But doesn't that just duplicate the behavior of `gemm_batched!` with some extra steps? In this case the switch would be relatively simple - just change some function names.

Why not rely on the multiple dispatch of `gemm_batched!`?
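
For illustration, one way the dispatch-based alternative could look (a hypothetical sketch built on the existing `gemm_batched!` wrapper; the helper name and the per-group transposes and scalars are assumptions, not the PR's API): loop over the groups and issue one batched GEMM per group, since each group is already a flat vector of same-sized matrices.

```julia
using CUDA
using CUDA.CUBLAS: gemm_batched!

# Hypothetical helper, not part of CUDA.jl: one gemm_batched! call per group.
function grouped_gemm_sketch!(transA::Vector{Char}, transB::Vector{Char},
                              alpha::Vector, A::Vector, B::Vector,
                              beta::Vector, C::Vector)
    for g in eachindex(A)
        # Each group is a plain Vector of same-sized CuMatrix, which the
        # existing flat gemm_batched! wrapper already handles.
        gemm_batched!(transA[g], transB[g], alpha[g], A[g], B[g], beta[g], C[g])
    end
    return C
end
```

This is essentially what lpawela describes above as duplicating `gemm_batched!` with some extra steps.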

lpawela commented 3 months ago

I restored the previous implementation alongside mine. @maleadt is this sufficient, or should I just overload `gemm_batched!`?

maleadt commented 2 months ago

Seems fine to me; I'll let @amontoison give the final OK though.

amontoison commented 2 months ago

LGTM!

maleadt commented 2 months ago

Weirdly, Enzyme.jl tests only seem to fail on this PR, even though I don't think `gemm_grouped_batched` is used anywhere?

lpawela commented 2 months ago

Strange, on Enzyme.jl with v0.12.22 the test/cuda.jl tests pass for me with this PR:

```
(Enzyme) pkg> st
Project Enzyme v0.12.22
Status `~/lib/Enzyme.jl/Project.toml`
  [fa961155] CEnum v0.5.0
  [052768ef] CUDA v5.4.2 `../CUDA.jl`
  [f151be2c] EnzymeCore v0.7.6
  [61eb1bfa] GPUCompiler v0.26.7
  [929cbde3] LLVM v8.0.0
  [d8793406] ObjectFile v0.4.1
  [21216c6a] Preferences v1.4.3
  [7cc45869] Enzyme_jll v0.0.133+0
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [de0858da] Printf
  [9a3f8284] Random
```

The test/cuda.jl tests also pass on the Enzyme main branch. The entire Enzyme test suite passes on v0.12.22 as well. On the main branch, however, the full test suite fails with:

```
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
Test Summary: | Pass  Total   Time
DiffTest      |   44     44  41.7s
2.3
2.3
Test Summary: | Pass  Total  Time
IO            |    4      4  4.4s
Test Summary: |Time
hmlstm        | None  0.1s
Test Summary:  |Time
No speculation | None  0.3s
ERROR: Package Enzyme errored during testing (received signal: 11)
```

maleadt commented 2 months ago

@wsmoses Can you look into this CI failure? It's pretty inscrutable to me.

wsmoses commented 2 months ago

Oh, I think you just got unlucky. That was fixed almost immediately after in Enzyme.
