With `LLAMA_CUDA_FORCE_DMMV` I see a significant improvement from these changes:
| GPU | Model | Test | t/s master | t/s q4_k_performance | Speedup |
|---|---|---|---|---|---|
| RTX 3090 Ti | llama 7B Q4_K_M | tg128 | 104.36 | 127.25 | 1.22 |
However, without `LLAMA_CUDA_FORCE_DMMV`, mmvq is used instead, which is significantly faster:
| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | tg128 | 146.10 ± 1.12 |
For this format I think mmvq should be automatically used for compute capability >= 6.1.
Similar results for RTX 2060:
```shell
LLAMA_CUDA_FORCE_DMMV=1 LLAMA_CUDA=1 ./scripts/compare-commits.sh master pr/8136 -m ./models/llama-7b-v2/ggml-model-q4_k.gguf -p 0 -ngl 99
```
| GPU | Model | Test | t/s master | t/s pr/8136 | Speedup |
|---|---|---|---|---|---|
| RTX 2060 SUPER | llama 7B Q4_K_M | tg128 | 29.47 | 40.27 | 1.37 |
```shell
LLAMA_CUDA=1 make -j && ./llama-bench -m ./models/llama-7b-v2/ggml-model-q4_k.gguf -p 0 -ngl 99
```

```
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
```

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 7B Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | tg128 | 69.08 ± 0.24 |
build: 50732c06 (3236)
> Use float4 to force aligned reads of 128 bit to increase memory bandwidth
I don't think that's the reason. `__builtin_assume_aligned` does not achieve any speedup. And unlike with e.g. `half2` there is no hardware support for `float4`; the compiler will simply convert it to `float` instructions. I think the reason for the performance difference is rather that with the code on master the compiler fails to group or schedule the memory accesses to `y1` and `y2` in the optimal way.

In any case, I can confirm that the performance increases with this change and that most of the difference comes from the `float4` change.
> And unlike with e.g. `half2` there is no hardware support for `float4`; the compiler will simply convert it to `float` instructions.
Unless I am completely misremembering, the CUDA documentation definitely said that the `float4` datatype is effectively just 4 floats. However, looking at the PTX code it seems that there is a dedicated instruction for loading a chunk of 16 bytes:
```
ld.global.nc.v4.f32 {%f9, %f10, %f11, %f12}, [%rd31+-512];
ld.global.nc.v4.f32 {%f17, %f18, %f19, %f20}, [%rd31+-384];
ld.global.nc.v4.f32 {%f25, %f26, %f27, %f28}, [%rd31];
ld.global.nc.v4.f32 {%f33, %f34, %f35, %f36}, [%rd31+128];
```
Assuming this can also be applied to other kernels, this is definitely a very good find.
@JohannesGaessler `float4` issues a single instruction to read 128 bit. In theory, the compiler should be capable of understanding this from the original code (without having to use "unsafe" methods such as `reinterpret_cast`), but the compiler is often imperfect.
@slaren My plan was to look at `dequantize_mul_mat_vec` (Q4_0) next and see if I can find opportunities for optimization. Would you say mmvq is used more frequently? Trying to understand what is the most common format.
> Unless I am completely misremembering, the CUDA documentation definitely said that the `float4` datatype is effectively just 4 floats. However, looking at the PTX code it seems that there is a dedicated instruction for loading a chunk of 16 bytes:
You are correct. A `float4` is 4 floats in 4 registers, and there is support for aligned vectorized loads (`LD*.(8|16|32|64|128)`, with the numbers being bits).
Loads/stores are split into transactions of sectors (32 bytes), with one sector being processed each clock. If a warp accesses multiple sectors they get serialized, which increases MIO utilization. Simplifying things a bit, memory efficiency can be considered as `number_of_bytes_accessed / (number_of_sectors_accessed * 32)`.
In this case 4 contiguous floats were read with 4 LD instructions. Two contiguous warps have a byte offset of 16 bytes, resulting in reading bytes 0-3 and 16-19 within a sector. Thus memory efficiency was 8/32 = 0.25. Using `float4` instead of `float` increases this metric to 1.0.
> Assuming this can also be applied to other kernels, this is definitely a very good find.
A full sweep with Nsight Compute through all kernels / kernel dimensions of the most important networks will reveal candidates where the optimization might help. FB utilization is the metric to look for most of the time for the llama.cpp kernels. A utilization >95% is perfect and >90% is pretty good; everything below might be using memory inefficiently and could thus benefit from vectorized loads and/or reordering of memory accesses.
> @slaren My plan was to look at `dequantize_mul_mat_vec` (Q4_0) next and see if I can find opportunities for optimization. Would you say mmvq is used more frequently? Trying to understand what is the most common format.
Yes, mmvq is used more frequently. At this point, the dmmv kernels are only used with GPUs that don't support dp4a (cc < 6.1).
> Would you say mmvq is used more frequently? Trying to understand what is the most common format.
All NVIDIA GPUs starting with compute capability 6.1 have the `__dp4a` instruction, which does per-byte integer dot products. This is what is used in the `mul_mat_vec_q` kernels, and unless you are compiling with `LLAMA_CUDA_FORCE_DMMV` the MMVQ kernels are always used if `__dp4a` is available. The most modern GPU that would benefit from optimizations to the `dequantize_mul_mat_vec` kernels is the P100, because it only has compute capability 6.0. With MMVQ the activations are quantized to 8 bit, so you can load 4 values with a single 32-bit value, and those values align more nicely with the quantized data in the weights.
> A full sweep with Nsight Compute through all kernels / kernel dimensions of the most important networks will reveal candidates where the optimization might help.
Basically the only kernels that are performance relevant are `mul_mat_vec_q` (matrix multiplication with batch size 1-8) and `mul_mat_q` (matrix multiplication with batch size > 8), as well as the ggml FlashAttention kernels. The former two deal with quantized weights, and right now the main issue in terms of bandwidth is that the data layout is not very CUDA-friendly. By design the same data layout (array of structs) is used for all backends, and as a consequence all of the quantized data blocks are aligned to only 2 or 4 bytes. The blocks are also very small, with only 16 or 32 quantized values per scale (for GPTQ it's 128, I think).
For MMVQ I've been thinking it would maybe be possible to just load the data contiguously and apply a bit mask instead of loading the data from 2/4 byte values that are not 100% contiguous.
> FB utilization is the metric to look for most of the time for the llama.cpp kernels. A utilization >95% is perfect, >90% is pretty good
I currently don't have access to an instance with Nsight Compute; MMVQ had ~90% "speed of light" memory utilization, but I don't know about FB specifically.
> By design the same data layout (array of structs) is used for all backends
This is not really the design; backends have the ability to change the layout of the tensor data. I can go into more detail if you think that could improve performance significantly, but essentially the backends can convert the data to whatever layout they want during the `set_tensor` function, although they should also provide a way to convert it back to the ggml layout in `get_tensor`.
I previously made prototypes where I converted the data to struct of arrays layout but I was not able to get better performance for MMVQ (though it's always possible that I just did it wrong). For MMQ it would maybe be useful because for efficient use of asynchronous data loading you need 16 byte alignment. But right now the int8 tensor core utilization is only 35% so there are probably still other problems that would need to be fixed. And unless you drop support for partial offloading completely you would need to implement and compile two separate instances for loading data per quantization format so I'm thinking that changing the data layout is comparatively a lot of work for the potential benefits.
> And unless you drop support for partial offloading completely you would need to implement and compile two separate instances for loading data per quantization format
I don't think this needs to be the case, during partial offloading the weights are also copied to VRAM using the backend interface, and they can be converted to a different layout during the copy. As long as it is only a change in layout and does not require any expensive computations, I don't think it would affect performance significantly.
In my prototype I did the conversion via a host->device `cudaMemcpy2D` and the performance was terrible, but maybe device->device is acceptable; I'd have to test it.
Given that SM load is low for the kernels of interest, instead of adding complexity to the codebase we could also add support for unaligned loads in CUDA:
```cpp
__device__ int load_unaligned_int(const void * address) {
    // Round the address down to the previous 4-byte boundary.
    uintptr_t address_int = reinterpret_cast<uintptr_t>(address);
    const uint32_t * address_int32 = reinterpret_cast<const uint32_t *>(address_int & ~uintptr_t(3));
    uint32_t offset = address_int & 3;
    uint32_t dword1 = *address_int32;
    if (offset) {
        // The value straddles two aligned words: load the next word and
        // funnel-extract the 4 bytes starting at the misalignment offset.
        uint32_t dword2 = *(address_int32 + 1);
        asm volatile("prmt.b32.f4e %0, %1, %2, %3;" : "=r"(dword1) : "r"(dword1), "r"(dword2), "r"(offset));
    }
    return dword1;
}
```
This is as good as it gets for loading 32 bits from an unaligned address. Given that the offset is 2-byte aligned, there is a 50% chance that only a single load is required; if the alignment is not given, 2 loads have to be done anyway and the result has to be combined as well.
https://godbolt.org/z/6zdT9osb3 has the code and SASS. The unaligned-load logic essentially adds only 2 LOP3 instructions and that's it. When streaming contiguous data, only a single additional load would be required for the whole stream.
```
LOP3.LUT P0, R0, R4.reuse, 0x3, RZ, 0xc0, !PT  ; required
IMAD.MOV.U32 R7, RZ, RZ, R5                    ; might be in the SASS or not depending on the register allocator strategy
LOP3.LUT R6, R4, 0xfffffffc, RZ, 0xc0, !PT     ; required
ULDC.64 UR4, c[0x0][0x118]                     ; won't be in the SASS if inlined
LD.E R4, [R6.64]                               ; load 1
@P0 LD.E R3, [R6.64+0x4]                       ; load 2
@P0 PRMT.F4E R4, R4, R0, R3                    ; combine the two ints
RET.ABS.NODEC R20 0x0                          ; just for the function
.L_x_0:
BRA `(.L_x_0)
```
The same can be done with 64-bit and 128-bit loads as long as there are free cycles in the ALU or FP unit to move around registers.
Relative to the current master code that loads the data as two 16-bit values, I'm measuring a speedup of ~1%. Is there a reason why you're using inline PTX assembly instead of `__byte_perm` (which I think does the same thing)?
Changes:

Performance

Using llama-bench I measured the end-to-end speedup.

Device 0: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes