JuliaGPU / KernelAbstractions.jl

Heterogeneous programming in Julia
MIT License

simple batched dot kernel is ~1.7x slower with Const on Titan RTX #479

Open bjarthur opened 3 months ago

bjarthur commented 3 months ago

could someone please help me understand why this should be the case? the PTX code is similar and the threads/blocks are identical.

using KernelAbstractions, CUDA, BenchmarkTools

L, N = 256, 32768
x = CuArray(rand(Float32, L,N));
y = CuArray(rand(Float32, L,N));
o = CuArray(rand(Float32, N));

# KA version: one work-item per column k; @Const marks x and y as read-only
@kernel function kernel_ka(o, @Const(x), @Const(y))
    k = @index(Global)

    @inbounds begin
        tmp = 0.0f0
        for i=1:size(x,1)
            tmp += x[i,k] * y[i,k]
        end
        o[k] = tmp
    end
end

# hand-written CUDA.jl version of the same per-column dot product
function kernel_cuda(o, x, y)
    k = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if k<=size(x,2)   # guard: surplus threads do nothing
        tmp = 0f0
        for i=1:size(x,1)
            tmp += x[i,k] * y[i,k]
        end
        o[k] = tmp
    end
    return nothing
end

# KA kernel instantiated for the CUDA backend with a static workgroup size of 32
batched_dot_ka! = kernel_ka(CUDABackend(), 32)
@benchmark CUDA.@sync batched_dot_ka!(o, x, y; ndrange=length(o))

# compiled but not launched (launch=false); 32 threads * 1024 blocks = 32768 = N columns
batched_dot_cuda! = @cuda name="batched_dot!" launch=false kernel_cuda(o, x, y)
@benchmark CUDA.@sync batched_dot_cuda!(o, x, y; threads=32, blocks=1024)

the above yields the times below for KA and CUDA, respectively, so KA is ~1.7x slower (median 491.5 μs vs. 294.5 μs):

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  382.110 μs …  3.590 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     491.497 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   474.179 μs ± 62.871 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                           ▆█                                   
  ▁▃▆▆▅▃▄█▆▄▃▃▂▂▂▂▂▂▂▂▄▅▃▂▃████▅▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  382 μs          Histogram: frequency by time          641 μs <

 Memory estimate: 1.72 KiB, allocs estimate: 62.

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  287.471 μs …  3.851 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     294.507 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   301.527 μs ± 47.135 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▄██▆▄▂▂▂                                                  ▂▁ ▂
  ███████████▇▇██▆▆█▆▃▅▁▁▄▃▃▄▃▁▄▁▃▁▁▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▄███ █
  287 μs        Histogram: log(frequency) by time       468 μs <

 Memory estimate: 336 bytes, allocs estimate: 16.

and here is the PTX code:

KA:

```
julia> @device_code_ptx batched_dot_ka!(o, x, y; ndrange=length(o))
// PTX CompilerJob of MethodInstance for gpu_kernel_ka(::KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicCheck, Nothing, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(32,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}}, ::CuDeviceVector{Float32, 1}, ::CuDeviceMatrix{Float32, 1}, ::CuDeviceMatrix{Float32, 1}) for sm_75, maxthreads=32

//
// Generated by LLVM NVPTX Back-End
//

.version 8.3
.target sm_75
.address_size 64

    // .globl _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE
    // -- Begin function _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE
    // @_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE
.visible .entry _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE(
    .param .align 8 .b8 _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_0[16],
    .param .align 8 .b8 _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_1[16],
    .param .align 8 .b8 _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_2[32],
    .param .align 8 .b8 _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_3[40],
    .param .align 8 .b8 _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_4[40]
)
.maxntid 32, 1, 1
{
    .reg .pred %p<4>;
    .reg .b32 %r<4>;
    .reg .f32 %f<10>;
    .reg .b64 %rd<32>;

// %bb.0: // %conversion
    ld.param.u64 %rd17, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_1];
    mov.u32 %r2, %ctaid.x;
    mov.u32 %r1, %tid.x;
    add.s32 %r3, %r1, 1;
    cvt.u64.u32 %rd18, %r3;
    mul.wide.u32 %rd6, %r2, 32;
    add.s64 %rd7, %rd6, %rd18;
    setp.gt.s64 %p1, %rd7, %rd17;
    @%p1 bra $L__BB0_5;
// %bb.1: // %L95
    ld.param.u64 %rd1, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_2];
    ld.param.u64 %rd3, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_3+16];
    setp.lt.s64 %p2, %rd3, 1;
    mov.f32 %f9, 0f00000000;
    @%p2 bra $L__BB0_4;
// %bb.2: // %L220.preheader
    ld.param.u64 %rd2, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_3];
    ld.param.u64 %rd4, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_4];
    ld.param.u64 %rd5, [_Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10StaticSizeI5_32__ES2_ILi1ES3_IS4_IS5_EEEvEE13CuDeviceArrayI7Float32Li1ELi1EES8_IS9_Li2ELi1EES8_IS9_Li2ELi1EE_param_4+16];
    max.s64 %rd31, %rd3, 0;
    add.s64 %rd19, %rd7, 4611686018427387903;
    mul.lo.s64 %rd20, %rd3, %rd19;
    max.s64 %rd21, %rd5, 0;
    mul.lo.s64 %rd22, %rd21, %rd19;
    shl.b64 %rd23, %rd22, 2;
    add.s64 %rd30, %rd4, %rd23;
    shl.b64 %rd24, %rd20, 2;
    add.s64 %rd29, %rd2, %rd24;
    mov.f32 %f9, 0f00000000;
$L__BB0_3: // %L220
           // =>This Inner Loop Header: Depth=1
    ld.global.nc.f32 %f6, [%rd29];
    ld.global.nc.f32 %f7, [%rd30];
    fma.rn.f32 %f9, %f6, %f7, %f9;
    add.s64 %rd31, %rd31, -1;
    add.s64 %rd30, %rd30, 4;
    add.s64 %rd29, %rd29, 4;
    setp.ne.s64 %p3, %rd31, 0;
    @%p3 bra $L__BB0_3;
$L__BB0_4: // %L381
    cvt.u64.u32 %rd25, %r1;
    add.s64 %rd26, %rd6, %rd25;
    shl.b64 %rd27, %rd26, 2;
    add.s64 %rd28, %rd27, %rd1;
    st.global.f32 [%rd28], %f9;
$L__BB0_5: // %L388
    ret;
    // -- End function
}
```

CUDA:

```
julia> @device_code_ptx @cuda name="batched_dot!" launch=false kernel_cuda(o, x, y)
// PTX CompilerJob of MethodInstance for kernel_cuda(::CuDeviceVector{Float32, 1}, ::CuDeviceMatrix{Float32, 1}, ::CuDeviceMatrix{Float32, 1}) for sm_75

//
// Generated by LLVM NVPTX Back-End
//

.version 8.3
.target sm_75
.address_size 64

    // .globl batched_dot_
    // -- Begin function batched_dot_
    // @batched_dot_
.visible .entry batched_dot_(
    .param .align 8 .b8 batched_dot__param_0[16],
    .param .align 8 .b8 batched_dot__param_1[32],
    .param .align 8 .b8 batched_dot__param_2[40],
    .param .align 8 .b8 batched_dot__param_3[40]
)
{
    .reg .pred %p<6>;
    .reg .b32 %r<5>;
    .reg .f32 %f<22>;
    .reg .b64 %rd<46>;

// %bb.0: // %conversion
    ld.param.u64 %rd25, [batched_dot__param_2+24];
    mov.u32 %r1, %tid.x;
    add.s32 %r2, %r1, 1;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mul.wide.u32 %rd6, %r3, %r4;
    cvt.u64.u32 %rd26, %r2;
    add.s64 %rd7, %rd6, %rd26;
    setp.gt.s64 %p1, %rd7, %rd25;
    @%p1 bra $L__BB0_8;
// %bb.1: // %L31
    ld.param.u64 %rd1, [batched_dot__param_1];
    ld.param.u64 %rd3, [batched_dot__param_2+16];
    setp.lt.s64 %p2, %rd3, 1;
    mov.f32 %f19, 0f00000000;
    @%p2 bra $L__BB0_7;
// %bb.2: // %L49.preheader
    ld.param.u64 %rd2, [batched_dot__param_2];
    ld.param.u64 %rd4, [batched_dot__param_3];
    ld.param.u64 %rd5, [batched_dot__param_3+16];
    max.s64 %rd8, %rd3, 0;
    add.s64 %rd9, %rd7, -1;
    mul.lo.s64 %rd10, %rd3, %rd9;
    max.s64 %rd28, %rd5, 0;
    mul.lo.s64 %rd11, %rd28, %rd9;
    and.b64 %rd12, %rd8, 1;
    setp.eq.s64 %p3, %rd8, 1;
    mov.f32 %f19, 0f00000000;
    mov.u64 %rd45, 0;
    @%p3 bra $L__BB0_5;
// %bb.3: // %L49.preheader.new
    and.b64 %rd13, %rd8, 9223372036854775806;
    shl.b64 %rd30, %rd11, 2;
    add.s64 %rd31, %rd30, %rd4;
    add.s64 %rd43, %rd31, 4;
    shl.b64 %rd32, %rd10, 2;
    add.s64 %rd33, %rd32, %rd2;
    add.s64 %rd42, %rd33, 4;
    mov.u64 %rd45, 0;
    mov.f32 %f19, 0f00000000;
$L__BB0_4: // %L49
           // =>This Inner Loop Header: Depth=1
    ld.global.f32 %f11, [%rd42+-4];
    ld.global.f32 %f12, [%rd43+-4];
    fma.rn.f32 %f13, %f11, %f12, %f19;
    ld.global.f32 %f14, [%rd42];
    ld.global.f32 %f15, [%rd43];
    fma.rn.f32 %f19, %f14, %f15, %f13;
    add.s64 %rd45, %rd45, 2;
    add.s64 %rd43, %rd43, 8;
    add.s64 %rd42, %rd42, 8;
    setp.ne.s64 %p4, %rd13, %rd45;
    @%p4 bra $L__BB0_4;
$L__BB0_5: // %L148.loopexit.unr-lcssa
    setp.eq.s64 %p5, %rd12, 0;
    @%p5 bra $L__BB0_7;
// %bb.6: // %L49.epil.preheader
    add.s64 %rd34, %rd45, %rd10;
    shl.b64 %rd35, %rd34, 2;
    add.s64 %rd23, %rd2, %rd35;
    add.s64 %rd36, %rd45, %rd11;
    shl.b64 %rd37, %rd36, 2;
    add.s64 %rd24, %rd4, %rd37;
    ld.global.f32 %f16, [%rd23];
    ld.global.f32 %f17, [%rd24];
    fma.rn.f32 %f19, %f16, %f17, %f19;
$L__BB0_7: // %L148
    cvt.u64.u32 %rd38, %r1;
    add.s64 %rd39, %rd6, %rd38;
    shl.b64 %rd40, %rd39, 2;
    add.s64 %rd41, %rd40, %rd1;
    st.global.f32 [%rd41], %f19;
$L__BB0_8: // %L156
    ret;
    // -- End function
}
```

vendor-agnostic code is really appealing but i'm not sure i'm willing to pay this much of a performance penalty for it. thanks!

vchuravy commented 3 months ago

Can you post a profile https://cuda.juliagpu.org/stable/development/profiling/#Integrated-profiler so that we can determine if the overhead is in the kernel or the kernel launch.

bjarthur commented 3 months ago

sure!

i made the problem 8-fold bigger in both L and N to emphasize the difference and got:

julia> CUDA.@profile batched_dot_ka!(o, x, y; ndrange=length(o))
Profiler ran for 90.6 ms, capturing 32 events.

Host-side activity: calling CUDA APIs took 616.31 µs (0.68% of the trace)
┌──────────┬────────────┬───────┬───────────────────────────────────────┬────────────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution                     │ Name                   │
├──────────┼────────────┼───────┼───────────────────────────────────────┼────────────────────────┤
│    0.67% │  607.49 µs │     1 │                                       │ cuLaunchKernel         │
│    0.00% │     3.1 µs │     6 │ 516.57 ns ± 462.72 (   0.0 ‥ 1192.09) │ cuStreamGetCaptureInfo │
└──────────┴────────────┴───────┴───────────────────────────────────────┴────────────────────────┘

Device-side activity: GPU was busy for 35.83 ms (39.55% of the trace)
┌──────────┬────────────┬───────┬──────────────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Name                                                                        ⋯
├──────────┼────────────┼───────┼──────────────────────────────────────────────────────────────────────────────
│   39.55% │   35.83 ms │     1 │ _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16Cartesian ⋯
└──────────┴────────────┴───────┴──────────────────────────────────────────────────────────────────────────────
                                                                                               1 column omitted

julia> CUDA.@profile batched_dot_cuda!(o, x, y; threads=32, blocks=1024)
Profiler ran for 3.49 ms, capturing 19 events.

Host-side activity: calling CUDA APIs took 621.8 µs (17.79% of the trace)
┌──────────┬────────────┬───────┬──────────────────────────────────────┬────────────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution                    │ Name                   │
├──────────┼────────────┼───────┼──────────────────────────────────────┼────────────────────────┤
│   17.69% │  618.22 µs │     1 │                                      │ cuLaunchKernel         │
│    0.05% │    1.67 µs │     3 │ 556.31 ns ± 364.19 (238.42 ‥ 953.67) │ cuStreamGetCaptureInfo │
└──────────┴────────────┴───────┴──────────────────────────────────────┴────────────────────────┘

Device-side activity: GPU was busy for 2.85 ms (81.59% of the trace)
┌──────────┬────────────┬───────┬──────────────┐
│ Time (%) │ Total time │ Calls │ Name         │
├──────────┼────────────┼───────┼──────────────┤
│   81.59% │    2.85 ms │     1 │ batched_dot_ │
└──────────┴────────────┴───────┴──────────────┘

so it's definitely the kernel, not the launch.

thanks for the quick reply!

vchuravy commented 3 months ago

If you changed the problem size then you need to change the number of blocks.

julia> CUDA.@profile batched_dot_cuda!(o, x, y; threads=32, blocks=round(Int, length(o)/32))
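A block count that tracks the problem size can be computed with ceiling division, which stays correct even when length(o) is not a multiple of the block size (round(Int, n/32) silently drops the last partial block in that case). A minimal sketch; the helper name nblocks is hypothetical:

```julia
# Hypothetical helper: smallest number of blocks with threads * blocks >= n
nblocks(n::Integer, threads::Integer) = cld(n, threads)

threads = 32
@show nblocks(32768, threads)    # 1024, the original problem size
@show nblocks(262144, threads)   # 8192, the 8x larger N
@show nblocks(100, threads)      # 4 blocks cover all 100 columns
@show round(Int, 100 / threads)  # 3 blocks: only 96 columns would be computed
```

In the benchmark this would read batched_dot_cuda!(o, x, y; threads=32, blocks=nblocks(length(o), 32)); the k<=size(x,2) guard in kernel_cuda makes the surplus threads in the last block do nothing.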
bjarthur commented 3 months ago

hah, right, how's this:

julia> CUDA.@profile batched_dot_ka!(o, x, y; ndrange=length(o))
Profiler ran for 37.48 ms, capturing 32 events.

Host-side activity: calling CUDA APIs took 603.91 µs (1.61% of the trace)
┌──────────┬────────────┬───────┬──────────────────────────────────────┬────────────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution                    │ Name                   │
├──────────┼────────────┼───────┼──────────────────────────────────────┼────────────────────────┤
│    1.60% │  599.86 µs │     1 │                                      │ cuLaunchKernel         │
│    0.00% │    1.19 µs │     6 │ 198.68 ns ± 278.72 (   0.0 ‥ 715.26) │ cuStreamGetCaptureInfo │
└──────────┴────────────┴───────┴──────────────────────────────────────┴────────────────────────┘

Device-side activity: GPU was busy for 36.71 ms (97.95% of the trace)
┌──────────┬────────────┬───────┬─────────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Name                                                                   ⋯
├──────────┼────────────┼───────┼─────────────────────────────────────────────────────────────────────────
│   97.95% │   36.71 ms │     1 │ _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16Cart ⋯
└──────────┴────────────┴───────┴─────────────────────────────────────────────────────────────────────────
                                                                                          1 column omitted

julia> CUDA.@profile batched_dot_cuda!(o, x, y; threads=32, blocks=round(Int, length(o)/32))
Profiler ran for 24.98 ms, capturing 19 events.

Host-side activity: calling CUDA APIs took 606.78 µs (2.43% of the trace)
┌──────────┬────────────┬───────┬──────────────────────────────────────┬────────────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution                    │ Name                   │
├──────────┼────────────┼───────┼──────────────────────────────────────┼────────────────────────┤
│    2.41% │  602.72 µs │     1 │                                      │ cuLaunchKernel         │
│    0.01% │    1.67 µs │     3 │ 556.31 ns ± 364.19 (238.42 ‥ 953.67) │ cuStreamGetCaptureInfo │
└──────────┴────────────┴───────┴──────────────────────────────────────┴────────────────────────┘

Device-side activity: GPU was busy for 24.35 ms (97.48% of the trace)
┌──────────┬────────────┬───────┬──────────────┐
│ Time (%) │ Total time │ Calls │ Name         │
├──────────┼────────────┼───────┼──────────────┤
│   97.48% │   24.35 ms │     1 │ batched_dot_ │
└──────────┴────────────┴───────┴──────────────┘

launch times still about the same with KA kernel being ~1.5x slower.

vchuravy commented 3 months ago

Ok that is still surprising to me. I expect some overhead but nothing that should scale like that.

vchuravy commented 3 months ago

What does CUDA.versioninfo() report on your machine?

Running this locally on a Quadro RTX 4000:

Device-side activity: GPU was busy for 1.98 ms (10.55% of the trace)
┌──────────┬────────────┬───────┬─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Name                                                                                                                                   ⋯
├──────────┼────────────┼───────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│    5.28% │  992.77 µs │     1 │ _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10Static ⋯
│    5.27% │  991.58 µs │     1 │ batched_dot_                                                                                                                           ⋯
└──────────┴────────────┴───────┴─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

and the 8x bigger case:

Device-side activity: GPU was busy for 120.19 ms (99.46% of the trace)
┌──────────┬────────────┬───────┬─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Name                                                                                                                                   ⋯
├──────────┼────────────┼───────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
│   49.74% │    60.1 ms │     1 │ batched_dot_                                                                                                                           ⋯
│   49.72% │   60.08 ms │     1 │ _Z13gpu_kernel_ka16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_10Static ⋯
└──────────┴────────────┴───────┴─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                                                                          1 column omitted
bjarthur commented 3 months ago
julia> CUDA.versioninfo()
CUDA runtime 12.4, artifact installation
CUDA driver 12.3
NVIDIA driver 545.23.8

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+545.23.8

Julia packages: 
- CUDA: 5.4.0
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.13.0+2

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

2 devices:
  0: NVIDIA TITAN RTX (sm_75, 23.594 GiB / 24.000 GiB available)
  1: NVIDIA TITAN RTX (sm_75, 23.636 GiB / 24.000 GiB available)

i have a variety of other GPUs available to test on if that'd be informative.

vchuravy commented 3 months ago

Okay, that makes it even more curious. We are looking at the same generation of GPUs; mine should be about 2x slower than yours, which matches.

Could you run the code through https://cuda.juliagpu.org/stable/development/profiling/#NVIDIA-Nsight-Compute / https://docs.nvidia.com/nsight-compute/NsightCompute/index.html

In particular a "Compute Workload Analysis".

Also for completeness, try without the Const:

@kernel function kernel_ka_no_const(o, x, y)
    k = @index(Global)

    @inbounds begin
        tmp = 0.0f0
        for i=1:size(x,1)
            tmp += x[i,k] * y[i,k]
        end
        o[k] = tmp
    end
end
bjarthur commented 3 months ago

it's the Const! removing it yields execution times comparable to CUDA. isn't Const supposed to possibly speed things up?

re. Nsight Compute's Compute Workload Analysis: with the CUDA kernel as the baseline, various metrics drop by 12-27% for the KA kernel with Const:

[Screenshot (2024-06-10): Nsight Compute Compute Workload Analysis, KA kernel with Const vs. CUDA baseline]
bjarthur commented 3 months ago

curiously, Const doesn't make a difference on a separate machine with an A100 (instead of TITAN RTX) [EDIT-- and the KA and CUDA kernels are equal in speed either way]:

julia> CUDA.versioninfo()
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.2

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+555.42.2

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

1 device:
  0: NVIDIA A100-SXM4-80GB (sm_80, 62.810 GiB / 80.000 GiB available)

i suppose it could be due to minor differences in the CUDA drivers too.

vchuravy commented 3 months ago

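Yeah, Const ends up as ldg (the ld.global.nc.f32 loads in the KA inner loop above, vs. plain ld.global.f32 in the CUDA kernel), but it's fascinating that this leads to a performance delta on prosumer chips.

You can also verify this by using Const with CUDA.jl directly, without KA getting involved: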

https://github.com/JuliaGPU/CUDA.jl/blob/e1e5be2b6bf17f03a367cebeb18c4645e593f80d/src/device/array.jl#L199
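A minimal sketch of that check, assuming a CUDA-capable GPU: the hand-written kernel from the first post with its inputs wrapped in CUDA.Const (the wrapper defined in the file linked above). The name kernel_cuda_const is hypothetical, and whether this reproduces the TITAN RTX slowdown is untested here.

```julia
using CUDA

# Same as kernel_cuda, but x and y are wrapped in CUDA.Const, marking them
# read-only for the kernel so that loads can go through ldg (ld.global.nc)
function kernel_cuda_const(o, x, y)
    x = CUDA.Const(x)
    y = CUDA.Const(y)
    k = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if k <= size(x, 2)
        tmp = 0f0
        for i = 1:size(x, 1)
            tmp += x[i, k] * y[i, k]
        end
        o[k] = tmp
    end
    return nothing
end

L, N = 256, 32768
x = CuArray(rand(Float32, L, N))
y = CuArray(rand(Float32, L, N))
o = CUDA.zeros(Float32, N)

@cuda threads=32 blocks=cld(N, 32) kernel_cuda_const(o, x, y)

# sanity check against a host-side reference
@assert Array(o) ≈ vec(sum(Array(x) .* Array(y); dims=1))
```

Timing it with the same @benchmark CUDA.@sync pattern used for kernel_cuda, and checking @device_code_ptx for ld.global.nc loads, would show whether the difference follows Const rather than KA.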

bjarthur commented 3 months ago

I expect some overhead but nothing that should scale like that.

can you please elaborate on why you expect some overhead with KA? in informal testing i now see near parity between KA and CUDA when the run times are long, but for small inputs KA becomes progressively slower in comparison. just curious why.

vchuravy commented 3 months ago

for small inputs KA becomes progressively slower in comparison. just curious why.

KA adds some additional integer operations for the index calculations and defaults to Int64. Reducing that overhead is a to-do, but I haven't found time for that.

This overhead is more noticeable on AMD.
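The Int64 point applies to hand-written CUDA.jl kernels too: threadIdx().x, blockIdx().x and blockDim().x are Int32, but subtracting an untyped literal 1 promotes the whole index expression to Int64 (consistent with the mul.wide.u32 and 64-bit adds in both PTX dumps above). A small host-side sketch of that promotion; global_index_64 and global_index_32 are hypothetical stand-ins for the device expression:

```julia
# Hypothetical stand-ins for the index expression in kernel_cuda; on the device,
# tid, bid and bdim would come from threadIdx().x, blockIdx().x, blockDim().x (all Int32)
global_index_64(tid::Int32, bid::Int32, bdim::Int32) = tid + (bid - 1) * bdim          # literal 1 is Int64
global_index_32(tid::Int32, bid::Int32, bdim::Int32) = tid + (bid - Int32(1)) * bdim   # stays Int32

k64 = global_index_64(Int32(5), Int32(3), Int32(32))
k32 = global_index_32(Int32(5), Int32(3), Int32(32))
@show k64 typeof(k64)   # 69, Int64
@show k32 typeof(k32)   # 69, Int32
```

The trade-off is that Int32 index arithmetic overflows once an array has more than typemax(Int32) elements, and whether narrowing the indices measurably helps depends on the kernel.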

bjarthur commented 3 months ago

KA adds some additional integer operations for the index calculations and defaults to Int64.

indexing overhead should scale with the problem size (ie input arg dims), no? what i'm seeing seems more like overhead in the kernel launch, as for small problems the difference in run times between KA and CUDA is large whereas the difference is small with large problems.

vchuravy commented 3 months ago

indexing overhead should scale with the problem size (ie input arg dims), no?

Latency hiding becomes more effective at larger problem sizes.

what i'm seeing seems more like overhead in the kernel launch,

That's not infeasible. But the launch code is https://github.com/JuliaGPU/CUDA.jl/blob/e1e5be2b6bf17f03a367cebeb18c4645e593f80d/src/CUDAKernels.jl#L89 which itself is fairly minimal.

bjarthur commented 3 months ago

is the indexing overhead in @index, or is it incurred for each indexing operation that the kernel performs on the input args (e.g. x[i,k] in the kernels above)?

bjarthur commented 3 months ago

using static ranges mitigates some of the performance gap i see between KA and CUDA for small problems. see https://github.com/JuliaGPU/KernelAbstractions.jl/issues/470
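For reference, a sketch of the two ways to instantiate the kernel, assuming KernelAbstractions 0.9's kernel(backend, workgroupsize, ndrange) form. With a static ndrange the range becomes part of the kernel's type instead of the DynamicSize visible in the KA PTX signature above. CPU() is used only so the sketch runs without a GPU; on the TITAN RTX one would pass CUDABackend() and CuArrays.

```julia
using KernelAbstractions

# Same kernel body as kernel_ka in the first post
@kernel function kernel_ka(o, @Const(x), @Const(y))
    k = @index(Global)
    @inbounds begin
        tmp = 0.0f0
        for i = 1:size(x, 1)
            tmp += x[i, k] * y[i, k]
        end
        o[k] = tmp
    end
end

backend = CPU()   # assumption: swap for CUDABackend() (with `using CUDA`) on a GPU
L, N = 256, 1024
x = rand(Float32, L, N)
y = rand(Float32, L, N)
o_dyn = zeros(Float32, N)
o_static = zeros(Float32, N)

# dynamic ndrange: the range is passed at every launch
batched_dot_dyn! = kernel_ka(backend, 32)
batched_dot_dyn!(o_dyn, x, y; ndrange = N)

# static workgroup size and static ndrange, fixed when the kernel is instantiated
batched_dot_static! = kernel_ka(backend, 32, (N,))
batched_dot_static!(o_static, x, y)

KernelAbstractions.synchronize(backend)
@assert o_dyn ≈ vec(sum(x .* y; dims = 1))
@assert o_static ≈ o_dyn
```

The catch is that a kernel with a static ndrange is tied to one problem size, so it has to be re-instantiated (and recompiled) for each new N.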

vchuravy commented 3 months ago

is the indexing overhead in @index

Yes. It should be CSE'd, and as you noted, constant ndranges can help as well.