Interestingly, when I try Q2_K (and it seems any other K-quant; I tried Q4_K_S to check) I see a different error:
[ 329/ 339] blk.27.ffn_gate.weight - [ 3584, 18944, 1, 1], type = f32, converting to q2_K .. ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0
(repeats several times)
Files for recreating it locally are up here: https://huggingface.co/bartowski/Qwen2-7B-Instruct-GGUF/tree/main
If 'Quill' is indeed the leaked version of Qwen2, then something might have changed in llama.cpp that broke Qwen2 conversions to GGUF. I am using a Quill GGUF from Mradermacher's repository from about six days ago, and it works fine. This assumes that the official release of Qwen2 isn't altered from Quill.
I think that the imatrix contains nan values. Did it complete without nan ppl?
@slaren are you referring to the --no-ppl option by chance or did you mean something else?
When I remove --no-ppl, the imatrix output looks like this:
llama_cpp_gpu-1 | [1]nan,[2]nan,[3]nan,[4]nan,[5]nan,[6]nan,[7]nan,[8]nan,[9]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 10 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix
llama_cpp_gpu-1 | [10]nan,[11]nan,[12]nan,[13]nan,[14]nan,[15]nan,[16]nan,[17]nan,[18]nan,[19]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 20 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix
llama_cpp_gpu-1 | [20]nan,[21]nan,[22]nan,[23]nan,[24]nan,[25]nan,[26]nan,[27]nan,[28]nan,[29]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 30 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix
Yes. If you ever get nan in these values, the imatrix is pretty much guaranteed to be useless. In fact, the model is completely unusable as long as that happens.
Yikes.. so what the hell am I doing wrong that isn't affecting others? (Or is it affecting others and it's just not obvious because they didn't calculate an imatrix?)
I even tried updating to master and using different datasets.
If it happens while using CUDA, it usually means that the model is producing activations that cannot be represented in a float16. We have workarounds to force float32 compute in these cases for the known models that need it, but it is very unusual.
I compute the imatrix for Qwen2 models on pure CPU, no CUDA build:
Ooh, so that's gotta be it then.. Strangely, I converted to f32 instead of f16, but I assume that's not enough?
trying to run my F32 conversion with CUDA gives pure garbage
running it with -ngl 0 gives actual results
It may work with flash attention (-fa).
The problem seems to be in the KV. This patch should allow it to work:
diff --git a/llama.cpp b/llama.cpp
index 32264a00..6324af70 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7009,7 +7009,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -15376,6 +15376,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, c
However, it also seems to work with flash attn even without GGML_PREC_F32. @JohannesGaessler do you know why that may be? Is the fattn implementation less susceptible to intermediate values out of range of f16?
As always super appreciative of the investigation and solve @slaren , you're awesome <3
is the fattn implementation less susceptible to intermediate values out of range of f16?
Assuming the problem is overflows rather than underflows, I think the most likely reason is that my FlashAttention implementations initially scale Q with 1.0f/sqrtf(float(n_embd_head)) (because it's faster), while regular attention does the KQ matrix multiplication first and applies the scale to the result. But this factor is not very large, so there is a good chance that for some specific inputs you would still get overflows.
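To make that concrete, here is a small standalone sketch (made-up magnitudes, plain C++, not the actual CUDA kernels); the only real constant is the float16 maximum of 65504:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative only: why scaling Q before the KQ dot product keeps the
// intermediate sum inside the float16 range, while scaling the result
// afterwards requires the unscaled sum itself to be representable.
int main() {
    const float FP16_MAX    = 65504.0f;  // largest finite float16 value
    const int   n_embd_head = 128;       // typical head dimension (assumed)
    const float scale       = 1.0f / std::sqrt((float) n_embd_head);

    // hypothetical activations large enough to cause trouble
    std::vector<float> q(n_embd_head, 30.0f), k(n_embd_head, 30.0f);

    float kq_post = 0.0f; // regular attention: dot product first, scale after
    float kq_pre  = 0.0f; // FlashAttention-style: Q scaled up front
    for (int i = 0; i < n_embd_head; ++i) {
        kq_post += k[i] * q[i];
        kq_pre  += k[i] * (q[i] * scale);
    }

    printf("unscaled KQ  = %.1f -> overflows f16: %s\n", kq_post, kq_post > FP16_MAX ? "yes" : "no");
    printf("prescaled KQ = %.1f -> overflows f16: %s\n", kq_pre,  kq_pre  > FP16_MAX ? "yes" : "no");
    return 0;
}

The final scaled score is the same either way; the difference is only whether the unscaled sum ever has to live in a 16-bit intermediate.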
@slaren would ROCm also be experiencing this bug or is it just in CUDA code? Curious if other odd bugs out there can be explained by this
Here's my current output for qwen2. llama.cpp seems to be utterly broken right now.
Here is my run command:
main -n -1 -m "qwen2-7b-instruct-q8_0.gguf" ^
--color --multiline-input -t 16 -ngl "900" ^
--log-disable -c 8192 --temp 0.6 --interactive-first ^
--interactive
@dillfrescott can you try adding -fa and seeing if that helps?
@bartowski1182 Damn, that totally fixed it! What does that flag do??
It enables flash attention; apparently it runs the operations in a different order, which prevents the issue.
Interesting
Here's my current output for qwen2. llama.cpp seems to be utterly broken right now.
Here is my run command:
main -n -1 -m "qwen2-7b-instruct-q8_0.gguf" ^
--color --multiline-input -t 16 -ngl "900" ^
--log-disable -c 8192 --temp 0.6 --interactive-first ^
--interactive
./main -m qwen2-7b-instruct-q5_k_m.gguf -n 512 --color -i -cml -f prompts/chat-with-qwen.txt
This is what I do. Are you using a GPU? It might be a CUDA issue, not sure. This also happens in Ollama for some users, but my friend inputs /set num_gpu 0 and finds it works. Things are a bit strange for me...
@JustinLin610 I am using a 4090 and have cuda version 12.5 installed. It usually works with every other model so I'm just not sure why.
@dillfrescott can you try adding -fa and seeing if that helps?
Any reason why, after enabling flash_attn for the imatrix calculation, the time to complete increased from about 2.5h to 5h (qwen2-72b)? I thought it'd be faster... Does it mean that without fa the imatrix was not computed properly?
@slaren would ROCm also be experiencing this bug or is it just in CUDA code? Curious if other odd bugs out there can be explained by this
Not him but the fundamental issue is that the KQ matrix has values outside the FP16 range. On Volta or newer or RX 7000 or newer llama.cpp does all matrix multiplications with a batch size > 64 using FP16 cuBLAS/HIPBLAS so you get garbage results. The fix works by instead calculating the KQ matrix with FP32 precision. The huge values from the KQ matrix are then fed to softmax which brings them into the 0-1 range again and fixes the numerical issues.
You should be able to fix the numerical issues with better performance on Ampere or newer by calculating the KQ matrix as BF16 instead of FP16 (with some precision loss which is probably negligible if you're going to apply softmax anyways). Annoyingly there is to my knowledge no way to make cuBLAS return the result as FP32 (even though that is the only output format that tensor cores actually support) because then we would also save some time on the conversion back to FP32.
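For anyone following along, the "fed to softmax which brings them into the 0-1 range" step is just the usual max-subtracted softmax; a minimal FP32 sketch (illustrative only, not the ggml implementation):

#include <algorithm>
#include <cmath>
#include <vector>

// Once a KQ row has been computed in FP32, a max-subtracted softmax maps
// arbitrarily large scores back into (0, 1], so the values flowing into the
// rest of the graph are safe to store in float16 again.
std::vector<float> softmax_f32(const std::vector<float> & scores) {
    const float m = *std::max_element(scores.begin(), scores.end());
    std::vector<float> p(scores.size());
    float sum = 0.0f;
    for (size_t i = 0; i < scores.size(); ++i) {
        p[i] = std::exp(scores[i] - m); // exponent argument is <= 0, cannot overflow
        sum += p[i];
    }
    for (float & v : p) {
        v /= sum;                       // every output lies in (0, 1]
    }
    return p;
}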
Any reason why, after enabling flash_attn for the imatrix calculation, the time to complete increased from about 2.5h to 5h (qwen2-72b)? I thought it'd be faster... Does it mean that without fa the imatrix was not computed properly?
Depends on what hardware you have, on AMD large batch FA performance is bad. I have received some reports about bad performance with partial offloading but so far I have never been able to reproduce this.
And note that I think just adding -fa is not a reliable fix, so there is a good chance that you will still get NaNs at some point.
Depends on what hardware you have, on AMD large batch FA performance is bad. I have received some reports about bad performance with partial offloading but so far I have never been able to reproduce this.
For normal inference fa works fine; it's running on an RTX 3090 on Linux, so only 12 layers are in VRAM. I never got any NaN, either with or without fa.
With -fa:
llama_new_context_with_model: n_batch = 204
llama_new_context_with_model: n_ubatch = 204
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 68.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 12.00 MiB
llama_new_context_with_model: KV self size = 80.00 MiB, K (f16): 40.00 MiB, V (f16): 40.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.58 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 2500.71 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 6.60 MiB
llama_new_context_with_model: graph nodes = 2487
llama_new_context_with_model: graph splits = 956
system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 117.352 ms
compute_imatrix: computing over 100 chunks with batch_size 204
ggml_backend_cuda_graph_compute: CUDA graph update failed
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to too many consecutive updates
ggml_backend_cuda_graph_compute: CUDA graph update failed
compute_imatrix: 177.86 seconds per pass - ETA 4 hours 56.43 minutes
Without -fa:
llama_new_context_with_model: n_ctx = 224
llama_new_context_with_model: n_batch = 204
llama_new_context_with_model: n_ubatch = 204
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 59.50 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 10.50 MiB
llama_new_context_with_model: KV self size = 70.00 MiB, K (f16): 35.00 MiB, V (f16): 35.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.58 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 2500.71 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 6.57 MiB
llama_new_context_with_model: graph nodes = 2806
llama_new_context_with_model: graph splits = 956
system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 119.284 ms
compute_imatrix: computing over 100 chunks with batch_size 204
ggml_backend_cuda_graph_compute: CUDA graph update failed
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to too many consecutive updates
ggml_backend_cuda_graph_compute: CUDA graph update failed
compute_imatrix: 98.19 seconds per pass - ETA 2 hours 43.65 minutes
@slaren Even with an imatrix generated with c9ee7118d5644dd3df70ea6878b36a9761616aab and -fa (no NaNs) quantization still fails:
[ 11/ 479] blk.0.ffn_down_exps.weight - [ 2560, 3584, 64, 1], type = bf16, converting to q2_K .. ggml_validate_row_data: found nan value at block 16
ggml_validate_row_data: found nan value at block 46
ggml_validate_row_data: found nan value at block 6
ggml_validate_row_data: found nan value at block 6
llama_model_quantize: failed to quantize: quantized data validation failed
I'm guessing this is due to the application of the imatrix to the weights overflowing?
I don't know, I would need to be able to reproduce it to be able to investigate why that happens.
@CISC can you share more details? I didn't get issues with just -fa. What dataset, and are you converting to f16 or f32 first?
@bartowski1182 I'm quantizing from bf16, the imatrix is generated from 250K tokens of groups-merged-enhancedV3 concatenated with wiki.train.raw.
Is it just me, or do almost all the tensors require fallback quantization when trying k/i-quants? This is 0.5B, but it does almost the same on all Qwen2 models:
llama_tensor_get_type : tensor cols 896 x 128 are not divisible by 256, required for q3_K - using fallback quantization iq4_nl
...
llama_tensor_get_type : tensor cols 896 x 896 are not divisible by 256, required for q4_K - using fallback quantization q5_0
...
llama_tensor_get_type : tensor cols 896 x 128 are not divisible by 256, required for q5_K - using fallback quantization q5_1
...
llama_model_quantize_internal: WARNING: 144 of 168 tensor(s) required fallback quantization
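For context, those warnings come down to a simple divisibility check: k-quants pack weights in super-blocks of QK_K = 256 values per row (in the default build), so a row width like 896 = 3*256 + 128 can never use them. A trivial sketch of the decision (can_use_k_quant is just an illustrative helper, not llama.cpp code):

#include <cstdint>
#include <cstdio>

static const int64_t QK_K = 256; // k-quant super-block size in the default build

// illustrative helper: k-quants require the row width to be a multiple of QK_K
static bool can_use_k_quant(int64_t n_cols) {
    return n_cols % QK_K == 0;
}

int main() {
    printf("  896 cols: %s\n", can_use_k_quant(896)  ? "k-quant ok" : "fallback needed"); // Qwen2 0.5B
    printf(" 3584 cols: %s\n", can_use_k_quant(3584) ? "k-quant ok" : "fallback needed"); // Qwen2 7B
    return 0;
}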
I'm using the imatrix on a 3090 with fa enabled. It doesn't matter if I converted to fp32 or fp16.
@slaren Ok, I believe I've located the issue. Whole blocks of x contain -0.0/0.0 weights here:
https://github.com/ggerganov/llama.cpp/blob/0c27e6f62eea80140daf152d7b6c154466614e5c/ggml-quants.c#L2181-L2186
which ultimately causes a divide-by-zero in make_qp_quants (taking sw as input):
https://github.com/ggerganov/llama.cpp/blob/0c27e6f62eea80140daf152d7b6c154466614e5c/ggml-quants.c#L2161
I'm not sure that 0-weights should be an actual issue, so the fix is probably to check that the suml2 division will not cause a NaN:
return suml2 ? sumlx/suml2 : 0.0f;
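To make the failure mode concrete, a stripped-down sketch (not the actual make_qp_quants code) of how an all-zero block turns the scale into NaN without that guard:

#include <cstdio>

// Simplified illustration: if an entire block of weights is 0.0/-0.0, both
// accumulators end up as zero, and the unguarded division 0.0f/0.0f yields
// NaN, which then poisons the quantized block.
static float block_scale(const float * x, const float * w, const int * l, int n) {
    float sumlx = 0.0f, suml2 = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumlx += w[i] * x[i] * l[i];
        suml2 += w[i] * l[i] * l[i];
    }
    return suml2 ? sumlx / suml2 : 0.0f; // guarded; without the check: NaN
}

int main() {
    const float x[4] = { 0.0f, -0.0f, 0.0f, -0.0f }; // block of zero weights
    const float w[4] = { 0.0f,  0.0f, 0.0f,  0.0f }; // importance weights also zero
    const int   l[4] = { 1, 1, 1, 1 };
    printf("scale = %f\n", block_scale(x, w, l, 4)); // prints 0.000000 instead of nan
    return 0;
}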
Yep, confirmed working on Q2_K at least, now onto all the others...
I see several other places that have a potential divide-by-zero, however a few places have the following, does this really work? https://github.com/ggerganov/llama.cpp/blob/0c27e6f62eea80140daf152d7b6c154466614e5c/ggml-quants.c#L1681
Edit: Yes, it does.
Don't see why it wouldn't, if the value is non 0 divide by it, otherwise 0.0f
Don't see why it wouldn't, if the value is non 0 divide by it, otherwise 0.0f
Yeah, I just wasn't sure that was applicable to -0.0f.
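A quick check confirms it: under IEEE 754, -0.0f compares equal to 0.0f and converts to false, so the ternary catches it as well:

#include <cstdio>

int main() {
    float nz = -0.0f;
    printf("-0.0f == 0.0f  : %d\n", nz == 0.0f ? 1 : 0);   // 1
    printf("-0.0f is truthy: %d\n", nz ? 1 : 0);            // 0
    printf("guarded recip  : %f\n", nz ? 1.0f / nz : 0.0f); // 0.0, not -inf
    return 0;
}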
Hmmm, IQ1_S fails spectacularly: besti1 >= 0 && besti2 >= 0 && best_shift != 0
...
I guess that makes sense since the sum of the weights is 0, but it shouldn't crash. :(
Sigh, there's lots more, now I'm hitting the grid_index >= 0 assert too.
@CISC Do you have a branch to test it?
@mann1x #7825. But there are other issues; the imatrix will magically contain NaNs, see #7816.
@JustinLin610 I am using a 4090 and have cuda version 12.5 installed. It usually works with every other model so I'm just not sure why.
Oh, -cml is for the ChatML format. I just saw it's missing from your command. But I think Bartowski mentioned the real problem.
After some fixing, fudging and a lot of debugging, I have found the reason for all the Qwen2 issues.
The models contain weights < ±1e-25, sometimes quite a lot of them. Depending on which format you converted them to, this could cause issues in itself, but it would cause NaNs during 16-bit KQ multiplications or certain quantizations for basically the same reason.
Unfortunately this makes models like 57B-A14B impossible to quantize except to a few formats that will not trigger this issue and/or don't require an imatrix.
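For reference, a rough sketch of the float16 magnitudes involved (assuming standard IEEE half precision, where the smallest positive subnormal is 2^-24 ≈ 5.96e-8): weights on the order of 1e-25 flush to signed zero as soon as they pass through a 16-bit path.

#include <cmath>
#include <cstdio>

int main() {
    const float fp16_min_subnormal = std::ldexp(1.0f, -24); // smallest positive float16 subnormal
    const float tiny_weight        = 1e-25f;                // magnitude reported for the Qwen2 tensors

    printf("smallest f16 subnormal: %g\n", fp16_min_subnormal);
    printf("weight %g survives f16: %s\n", tiny_weight,
           std::fabs(tiny_weight) >= fp16_min_subnormal ? "yes" : "no (flushes to +/-0)");
    return 0;
}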
Unfortunately this makes models like 57B-A14B impossible to quantize except to a few formats that will not trigger this issue and/or don't require an imatrix.
I made some quants without imatrix for ollama: https://ollama.com/mannix/qwen2-57b/tags
They seem to work properly but I didn't have time to test them extensively. Should I double check?
I made some quants without imatrix for ollama: https://ollama.com/mannix/qwen2-57b/tags They seem to work properly but I didn't have time to test them extensively. Should I double check?
No, those quants should be ok, at least if you made them from F32.
No, those quants should be ok, at least if you made them from F32.
Thanks for the confirmation! Yes indeed, I picked f32 for the bin as suggested.
@CISC Thanks for your great work!
I'd like to understand the scope of this issue's impact: are quants without an imatrix OK with the F32 KQ patch, or are they not safe either? Additionally, is it possible that this issue might cause some quants to appear functional while the output quality degrades for certain inputs?
I'd like to understand the scope of this issue's impact: are quants without an imatrix OK with the F32 KQ patch, or are they not safe either? Additionally, is it possible that this issue might cause some quants to appear functional while the output quality degrades for certain inputs?
I'd imagine any quantization will severely degrade the weights in question; how much of an issue that will be in the final results remains to be seen.
I hope someone points the qwen2 authors to this discussion.
They're aware of it, but is there anything they can or should do..?
What happened?
When attempting to quantize Qwen2 7B instruct to IQ2_XS I get the following assert:
Anything I can provide to debug? Uploading the f32 file and imatrix now for recreation
Attempting IQ2_S now, will update if it fails in the same way.
Update: it fails in the same way on the same block.
Name and Version
Version b3086, ubuntu 22.04
What operating system are you seeing the problem on?
Linux
Relevant log output
(PS: is this a high severity or medium/low?)