Open maleadt opened 1 month ago
Hmm, this is a little disconcerting: even a very simple Cartesian kernel exhibits a very significant slowdown.
Reduced from our broadcast
implementation:
using CUDA, KernelAbstractions, Chairmarks
# Benchmark the legacy hand-written broadcast kernel against the
# KernelAbstractions.jl port, broadcasting a scalar (0-d broadcast).
function main()
    dst = CuArray{Float32}(undef, 512, 1000)
    bc = Broadcast.broadcasted(identity, 0f0)
    # Re-instantiate with the destination's axes so the 0-d broadcast
    # takes on the 2-d shape of `dst`.
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(dst)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(dst, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(dst, bc))
end
# Baseline: hand-rolled CUDA kernel that maps a global linear thread index
# onto a Cartesian index of `dest`. Mirrors CUDA.jl's broadcast `copyto!`.
@inline function copyto_old!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    # Strip out unneeded broadcast metadata before capturing `bc` in the kernel.
    bc = Broadcast.preprocess(dest, bc)
    function broadcast_kernel(dest, bc)
        # 1-based global linear index of this thread.
        lin = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        @inbounds if lin <= length(dest)
            idx = CartesianIndices(dest)[lin]
            dest[idx] = bc[idx]
        end
        return
    end
    # Compile without launching so the occupancy API can size the launch.
    compiled = @cuda launch=false broadcast_kernel(dest, bc)
    cfg = launch_configuration(compiled.fun)
    nthreads = min(length(dest), cfg.threads)
    nblocks = cld(length(dest), nthreads)
    compiled(dest, bc; threads=nthreads, blocks=nblocks)
    return dest
end
# KernelAbstractions.jl port: the backend derives the Cartesian index for us
# via `@index(Global, Cartesian)` instead of a hand-computed linear index.
@inline function copyto_new!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)
    @kernel function broadcast_kernel(dest, bc)
        idx = @index(Global, Cartesian)
        @inbounds dest[idx] = bc[idx]
    end
    backend = get_backend(dest)
    broadcast_kernel(backend)(dest, bc; ndrange=size(dest))
    return dest
end
Without CUDA.@sync
, i.e., measuring launch overhead:
julia> main()
Old: 3.199 μs (12 allocs: 240 bytes)
New: 5.192 μs (58 allocs: 1.859 KiB)
With CUDA.@sync
, i.e., measuring execution time:
julia> main()
Old: 7.746 μs (12 allocs: 240 bytes)
New: 10.940 μs (58 allocs: 1.859 KiB)
The overhead scales, e.g., using 4k x 4k inputs instead:
julia> main()
Old: 30.230 μs (12 allocs: 240 bytes)
New: 61.250 μs (58 allocs: 1.859 KiB)
Generated code looks pretty bad, with both extra exceptions, branches, and argument mangling:
define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L165, label %L30
L30: ; preds = %conversion
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.not6 = icmp eq i64 %.fca.2.0.extract, 0
br i1 %.not6, label %fail, label %pass
L165: ; preds = %pass, %conversion
ret void
fail: ; preds = %L30
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception16 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass: ; preds = %L30
%11 = add nsw i64 %10, -1
%12 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%13 = getelementptr inbounds float, float addrspace(1)* %12, i64 %11
store float %.fca.0.0.extract, float addrspace(1)* %13, align 4
br label %L165
}
define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
%.not = icmp eq i64 %.fca.1.0.0.0.0.extract, 0
br i1 %.not, label %fail, label %pass
L527: ; preds = %pass6, %pass2
ret void
fail: ; preds = %conversion
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass: ; preds = %conversion
%.not15 = icmp eq i64 %.fca.1.1.0.0.0.extract, 0
br i1 %.not15, label %fail1, label %pass2
fail1: ; preds = %pass
call fastcc void @gpu_report_exception({ i64, i32 } %state, i64 ptrtoint ([10 x i8]* @exception19 to i64))
call fastcc void @gpu_signal_exception({ i64, i32 } %state)
call void @llvm.trap()
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
pass2: ; preds = %pass
%3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%4 = zext i32 %3 to i64
%5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%6 = zext i32 %5 to i64
%7 = sdiv i64 %6, %.fca.1.0.0.0.0.extract
%.neg16 = mul i64 %7, %.fca.1.0.0.0.0.extract
%8 = sdiv i64 %4, %.fca.1.1.0.0.0.extract
%9 = add nsw i64 %8, 1
%reass.add17 = add i64 %.neg16, %8
%reass.add = sub i64 %6, %reass.add17
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
%10 = add nuw nsw i64 %4, 1
%11 = add i64 %10, %reass.mul
%12 = mul i64 %7, %.fca.1.1.0.1.0.extract
%13 = add i64 %9, %12
%14 = icmp sgt i64 %11, 0
%15 = icmp sle i64 %11, %.fca.0.0.0.0.extract
%16 = and i1 %14, %15
%17 = icmp sgt i64 %13, 0
%18 = icmp sle i64 %13, %.fca.0.0.1.0.extract
%19 = and i1 %17, %18
%20 = and i1 %19, %16
br i1 %20, label %pass6, label %L527
pass6: ; preds = %pass2
%21 = add i64 %12, %8
%22 = mul i64 %21, %.fca.2.0.extract
%23 = add i64 %22, %4
%24 = add i64 %23, %reass.mul
%25 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%26 = getelementptr inbounds float, float addrspace(1)* %25, i64 %24
store float %.fca.0.0.extract, float addrspace(1)* %26, align 4
br label %L527
}
This results in a much higher register usage, from 28 to 64 (at the PTX level).
cc @vchuravy
Looks like most of the added code comes from KA's nditeration handling:
pass2: ; preds = %pass
; │└└└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:92 within `#threadIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:46 within `threadIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
%12 = zext i32 %11 to i64
; │└└└└└└└
; │┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:78 within `#blockIdx`
; ││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:56 within `blockIdx_x`
; │││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `_index`
; ││││┌ @ /home/tim/Julia/pkg/CUDA/src/device/intrinsics/indexing.jl:7 within `macro expansion` @ /home/tim/.julia/packages/LLVM/joxPv/src/interop/base.jl:39
%13 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
; │└└└└
; │┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand`
; ││┌ @ abstractarray.jl:1312 within `getindex`
; │││┌ @ abstractarray.jl:1353 within `_getindex`
; ││││┌ @ abstractarray.jl:1360 within `_to_subscript_indices`
; │││││┌ @ abstractarray.jl:1382 within `_unsafe_ind2sub`
; ││││││┌ @ abstractarray.jl:3053 within `_ind2sub` @ abstractarray.jl:3091
; │││││││┌ @ int.jl:86 within `-`
%14 = zext i32 %13 to i64
; │││││││└
; │││││││┌ @ abstractarray.jl:3104 within `_ind2sub_recurse`
; ││││││││┌ @ abstractarray.jl:3111 within `_div`
; │││││││││┌ @ int.jl:295 within `div`
%15 = sdiv i64 %14, %.fca.1.0.0.0.0.extract
%.neg17 = mul i64 %15, %.fca.1.0.0.0.0.extract
%16 = sdiv i64 %12, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse` @ abstractarray.jl:3099
; ││││││││┌ @ abstractarray.jl:3109 within `_lookup`
; │││││││││┌ @ int.jl:87 within `+`
%17 = add nsw i64 %16, 1
%reass.add18 = add i64 %.neg17, %16
%reass.add = sub i64 %14, %reass.add18
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
; ││││││││└└
; ││││││││ @ abstractarray.jl:3105 within `_ind2sub_recurse`
; ││││││││┌ @ int.jl:87 within `+`
%18 = add nuw nsw i64 %12, 1
; ││└└└└└└└
; ││ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:84 within `expand` @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:74
; ││┌ @ ntuple.jl:49 within `ntuple`
; │││┌ @ /home/tim/.julia/packages/KernelAbstractions/491pi/src/nditeration.jl:78 within `#1`
; ││││┌ @ int.jl:87 within `+`
%19 = add i64 %18, %reass.mul
; ││││└
; ││││┌ @ int.jl:88 within `*`
%20 = mul i64 %15, %.fca.1.1.0.1.0.extract
; ││││└
; ││││┌ @ int.jl:87 within `+`
%21 = add i64 %17, %20
; │└└└└
; │ @ /home/tim/Julia/pkg/CUDA/src/CUDAKernels.jl:168 within `#__validindex`
; │┌ @ multidimensional.jl:477 within `in`
; ││┌ @ tuple.jl:383 within `map`
; │││┌ @ range.jl:1426 within `in`
; ││││┌ @ int.jl:514 within `<=`
%22 = icmp sgt i64 %19, 0
%23 = icmp sle i64 %19, %.fca.0.0.0.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
%24 = and i1 %22, %23
; ││││└
; ││││┌ @ int.jl:514 within `<=`
%25 = icmp sgt i64 %21, 0
%26 = icmp sle i64 %21, %.fca.0.0.1.0.extract
; ││││└
; ││││┌ @ bool.jl:38 within `&`
%27 = and i1 %25, %26
; ││└└└
; ││┌ @ tuple.jl:664 within `all`
; │││┌ @ bool.jl:38 within `&`
%28 = and i1 %27, %24
; └└└└
br i1 %28, label %L242, label %L538
A couple of other things that stand out: the generated code uses `sdiv` (signed division), and CUDA.jl's kernel reads `blockDim`, while KA.jl somehow doesn't (I guess it never computes a global linear index?).

Testing on https://github.com/JuliaGPU/KernelAbstractions.jl/pull/518, a bit of performance is recovered, but it remains bad:
Old: 29.800 μs (12 allocs: 240 bytes)
New: 56.749 μs (58 allocs: 1.859 KiB)
Updated MWE:
using CUDA, KernelAbstractions, Chairmarks
using LLVM, LLVM.Interop
# Updated MWE: larger 4000x4000 destination, and `assume` (from LLVM.Interop)
# available so the old kernel can assert positive extents.
function main()
    dst = CuArray{Float32}(undef, 4000, 4000)
    bc = Broadcast.broadcasted(identity, 0f0)
    # Give the scalar broadcast the 2-d axes of the destination.
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(dst)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(dst, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(dst, bc))
end
# Baseline hand-written CUDA kernel, now with `assume` hints so LLVM knows
# every extent of `dest` is positive (removes sign checks on the ind2sub).
@inline function copyto_old!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)
    function broadcast_kernel(dest, bc)
        # 1-based global linear index of this thread.
        lin = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        # Tell LLVM all dimensions are nonzero/positive.
        assume.(size(dest) .> 0)
        @inbounds if lin <= length(dest)
            idx = CartesianIndices(dest)[lin]
            dest[idx] = bc[idx]
        end
        return
    end
    # Compile first, then size the launch from the occupancy API.
    compiled = @cuda launch=false broadcast_kernel(dest, bc)
    cfg = launch_configuration(compiled.fun)
    nthreads = min(length(dest), cfg.threads)
    nblocks = cld(length(dest), nthreads)
    compiled(dest, bc; threads=nthreads, blocks=nblocks)
    return dest
end
# KernelAbstractions.jl version, unchanged from the first MWE: the backend
# handles the global Cartesian index and the implicit bounds check.
@inline function copyto_new!(dest::AbstractArray, bc)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)
    @kernel function broadcast_kernel(dest, bc)
        idx = @index(Global, Cartesian)
        @inbounds dest[idx] = bc[idx]
    end
    backend = get_backend(dest)
    broadcast_kernel(backend)(dest, bc; ndrange=size(dest))
    return dest
end
Old:
define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI17DefaultArrayStyleILi0EE5TupleI5OneToI5Int64ES8_E8identityS5_IS0_EE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L176, label %pass
L176: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%14 = add nsw i64 %10, -1
%15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
br label %L176
}
New:
define ptx_kernel void @_Z20gpu_broadcast_kernel16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi2E5TupleI5OneToI5Int64ES6_EE7NDRangeILi2ES0_S0_S8_S8_EE13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI17DefaultArrayStyleILi0EES7_8identityS3_ISD_EE({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
%.fca.1.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 1, 0
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
%3 = icmp sgt i64 %.fca.1.0.0.0.0.extract, 0
call void @llvm.assume(i1 %3)
%4 = icmp sgt i64 %.fca.1.0.0.1.0.extract, 0
call void @llvm.assume(i1 %4)
%5 = icmp sgt i64 %.fca.1.1.0.0.0.extract, 0
call void @llvm.assume(i1 %5)
%6 = icmp sgt i64 %.fca.1.1.0.1.0.extract, 0
call void @llvm.assume(i1 %6)
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = zext i32 %7 to i64
%9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%10 = zext i32 %9 to i64
%11 = udiv i64 %10, %.fca.1.0.0.0.0.extract
%.neg22 = mul i64 %11, %.fca.1.0.0.0.0.extract
%12 = udiv i64 %8, %.fca.1.1.0.0.0.extract
%13 = add nuw nsw i64 %12, 1
%reass.add23 = add i64 %.neg22, %12
%reass.add = sub i64 %10, %reass.add23
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
%14 = add nuw nsw i64 %8, 1
%15 = add i64 %14, %reass.mul
%16 = mul i64 %11, %.fca.1.1.0.1.0.extract
%17 = add i64 %13, %16
%18 = icmp sgt i64 %15, 0
%19 = icmp sle i64 %15, %.fca.0.0.0.0.extract
%20 = and i1 %18, %19
%21 = icmp sgt i64 %17, 0
%22 = icmp sle i64 %17, %.fca.0.0.1.0.extract
%23 = and i1 %21, %22
%24 = and i1 %23, %20
br i1 %24, label %pass6, label %L585
L585: ; preds = %pass6, %conversion
ret void
pass6: ; preds = %conversion
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%25 = add i64 %16, %12
%26 = mul i64 %25, %.fca.2.0.extract
%27 = add i64 %26, %8
%28 = add i64 %27, %reass.mul
%29 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%30 = getelementptr inbounds float, float addrspace(1)* %29, i64 %28
store float %.fca.0.0.extract, float addrspace(1)* %30, align 4
br label %L585
}
The added arithmetic instructions are very clear now, and the `udiv`s are probably still the culprit.
It's curious that the old version doesn't have any integer divisions to convert the linear index into a global one. Presumably this is possible because the array style of the broadcast is 0D (we're broadcasting a simple scalar), and when using KA.jl that somehow gets lost.
Indeed, when broadcasting an actual array (where the old code would have to udiv
too), the performance is much closer:
# Variant benchmark: broadcast an actual 2-d array rather than a scalar, so
# the old kernel also needs an integer division for its ind2sub conversion.
function main()
    dst = CuArray{Float32}(undef, 4000, 4000)
    src = CuArray{Float32}(undef, 4000, 4000)
    bc = Broadcast.broadcasted(identity, src)
    bc = Broadcast.instantiate(Broadcast.Broadcasted(bc.f, bc.args, axes(dst)))
    print("Old: ")
    display(@b CUDA.@sync copyto_old!(dst, bc))
    print("New: ")
    display(@b CUDA.@sync copyto_new!(dst, bc))
end
Old: 153.629 μs (42 allocs: 720 bytes)
New: 161.238 μs (88 allocs: 2.547 KiB)
Old:
define ptx_kernel void @_Z16broadcast_kernel13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryE5TupleI5OneToI5Int64ES9_E8identityS6_I8ExtrudedIS1_S6_I4BoolSD_ES6_IS8_S8_EEEE({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L221, label %pass
L221: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 1
%.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 2, 0
%.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 1
%.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 1, 0
%.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 2, 0
%.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %1, 0, 0, 0, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = add nsw i64 %10, -1
%14 = udiv i64 %13, %.fca.2.0.extract
%15 = mul i64 %14, %.fca.2.0.extract
%16 = sub i64 %13, %15
%17 = add i64 %16, 1
%18 = add nuw nsw i64 %14, 1
%19 = and i8 %.fca.0.0.1.0.extract, 1
%.not7 = icmp eq i8 %19, 0
%20 = select i1 %.not7, i64 %.fca.0.0.2.0.extract, i64 %17
%21 = and i8 %.fca.0.0.1.1.extract, 1
%.not8 = icmp eq i8 %21, 0
%22 = select i1 %.not8, i64 %.fca.0.0.2.1.extract, i64 %18
%23 = add i64 %22, -1
%24 = mul i64 %23, %.fca.0.0.0.2.0.extract
%25 = add i64 %24, -1
%26 = add i64 %25, %20
%27 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
%28 = getelementptr inbounds float, float addrspace(1)* %27, i64 %26
%29 = load float, float addrspace(1)* %28, align 4
%30 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%31 = getelementptr inbounds float, float addrspace(1)* %30, i64 %13
store float %29, float addrspace(1)* %31, align 4
br label %L221
}
Notice the `udiv`.
KA.jl:
define ptx_kernel void @_Z20gpu_broadcast_kernel16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi2E5TupleI5OneToI5Int64ES6_EE7NDRangeILi2ES0_S0_S8_S8_EE13CuDeviceArrayI7Float32Li2ELi1EE11BroadcastedI12CuArrayStyleILi2E12DeviceMemoryES7_8identityS3_I8ExtrudedISE_S3_I4BoolSL_ES3_IS5_S5_EEEE({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract6 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract7 = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 0, 0
%.fca.1.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 0, 0, 1, 0
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x [1 x i64]]]] } %0, 1, 1, 0, 1, 0
%3 = icmp sgt i64 %.fca.1.0.0.0.0.extract, 0
call void @llvm.assume(i1 %3)
%4 = icmp sgt i64 %.fca.1.0.0.1.0.extract, 0
call void @llvm.assume(i1 %4)
%5 = icmp sgt i64 %.fca.1.1.0.0.0.extract, 0
call void @llvm.assume(i1 %5)
%6 = icmp sgt i64 %.fca.1.1.0.1.0.extract, 0
call void @llvm.assume(i1 %6)
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = zext i32 %7 to i64
%9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%10 = zext i32 %9 to i64
%11 = udiv i64 %10, %.fca.1.0.0.0.0.extract
%.neg28 = mul i64 %11, %.fca.1.0.0.0.0.extract
%12 = udiv i64 %8, %.fca.1.1.0.0.0.extract
%13 = add nuw nsw i64 %12, 1
%reass.add29 = add i64 %.neg28, %12
%reass.add = sub i64 %10, %reass.add29
%reass.mul = mul i64 %reass.add, %.fca.1.1.0.0.0.extract
%14 = add nuw nsw i64 %8, 1
%15 = add i64 %14, %reass.mul
%16 = mul i64 %11, %.fca.1.1.0.1.0.extract
%17 = add i64 %13, %16
%18 = icmp sgt i64 %15, 0
%19 = icmp sle i64 %15, %.fca.0.0.0.0.extract6
%20 = and i1 %18, %19
%21 = icmp sgt i64 %17, 0
%22 = icmp sle i64 %17, %.fca.0.0.1.0.extract7
%23 = and i1 %21, %22
%24 = and i1 %23, %20
br i1 %24, label %pass6, label %L630
L630: ; preds = %pass6, %conversion
ret void
pass6: ; preds = %conversion
%.fca.0.0.2.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 1
%.fca.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 2, 0
%.fca.0.0.1.1.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 1
%.fca.0.0.1.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 1, 0
%.fca.0.0.0.2.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 2, 0
%.fca.0.0.0.0.extract = extractvalue { [1 x { { i8 addrspace(1)*, i64, [2 x i64], i64 }, [2 x i8], [2 x i64] }], [2 x [1 x i64]] } %2, 0, 0, 0, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%25 = sub i64 %10, %.neg28
%26 = mul i64 %12, %.fca.1.1.0.0.0.extract
%27 = sub i64 %8, %26
%28 = mul i64 %25, %.fca.1.1.0.0.0.extract
%29 = add i64 %28, %27
%30 = add i64 %29, 1
%31 = add i64 %16, %12
%32 = add i64 %31, 1
%33 = and i8 %.fca.0.0.1.0.extract, 1
%.not = icmp eq i8 %33, 0
%34 = select i1 %.not, i64 %.fca.0.0.2.0.extract, i64 %30
%35 = and i8 %.fca.0.0.1.1.extract, 1
%.not27 = icmp eq i8 %35, 0
%36 = select i1 %.not27, i64 %.fca.0.0.2.1.extract, i64 %32
%37 = add i64 %36, -1
%38 = mul i64 %37, %.fca.0.0.0.2.0.extract
%39 = add i64 %38, -1
%40 = add i64 %39, %34
%41 = bitcast i8 addrspace(1)* %.fca.0.0.0.0.extract to float addrspace(1)*
%42 = getelementptr inbounds float, float addrspace(1)* %41, i64 %40
%43 = load float, float addrspace(1)* %42, align 4
%44 = mul i64 %31, %.fca.2.0.extract
%45 = add i64 %29, %44
%46 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%47 = getelementptr inbounds float, float addrspace(1)* %46, i64 %45
store float %43, float addrspace(1)* %47, align 4
br label %L630
}
Still a lot more code, and an additional `udiv`, but at least it makes the hypothesis more likely.
https://github.com/JuliaGPU/KernelAbstractions.jl/pull/539 results in only a single `sdiv`, which is an improvement, but still not the zero-division case reported here (which comes from broadcasting a scalar).
With https://github.com/JuliaGPU/KernelAbstractions.jl/pull/539, performance is interestingly worse. For the scalar broadcast:
julia> main()
Old: 30.260 μs (12 allocs: 240 bytes)
New: 72.179 μs (53 allocs: 1.891 KiB)
(as opposed to 61us on master, and 56us on https://github.com/JuliaGPU/KernelAbstractions.jl/pull/518)
For a 2D broadcast:
julia> main()
Old: 154.619 μs (42 allocs: 720 bytes)
New: 165.908 μs (84 allocs: 2.594 KiB)
... which is again a bit slower than before.
Adding some `assume` calls to get rid of all exceptions (both the `div`-related one, the newly added `DivideError`, and a remaining `InexactError`), I get:
julia> main()
Old: 28.200 μs (12 allocs: 240 bytes)
New: 50.119 μs (53 allocs: 1.891 KiB)
julia> main()
Old: 155.308 μs (42 allocs: 720 bytes)
New: 160.239 μs (83 allocs: 2.578 KiB)
The switch to KA.jl significantly slowed down several operations.
CUDA.jl:
permutedims
,broadcast
, and many othershttps://speed.juliagpu.org/changes/?tre=10&rev=6221589f5befec8f6f157a5a5271667dba09d0b6&exe=11&env=1
Metal.jl:
permutedims