wangzhaode / llm-export

llm-export can export LLM models to ONNX.
Apache License 2.0

[Request]: Help in adding support for Models with Grouped Query Attention (GQA) #21

Open Nick-infinity opened 7 months ago

Nick-infinity commented 7 months ago

Hello, I am trying to add support for models with GQA, e.g. TinyLlama. The indicator for grouped query attention is num_key_value_heads < num_attention_heads in the config.json file.

This is the config.json file for the TinyLlama model:


"attention_bias": false,
--
  | "bos_token_id": 1,
  | "eos_token_id": 2,
  | "hidden_act": "silu",
  | "hidden_size": 2048,
  | "initializer_range": 0.02,
  | "intermediate_size": 5632,
  | "max_position_embeddings": 2048,
  | "model_type": "llama",
  | "num_attention_heads": 32,
  | "num_hidden_layers": 22,
  | "num_key_value_heads": 4,
  | "pretraining_tp": 1,
  | "rms_norm_eps": 1e-05,
  | "rope_scaling": null,
  | "rope_theta": 10000.0,
  | "tie_word_embeddings": false,
  | "torch_dtype": "bfloat16",
  | "transformers_version": "4.35.0",
  | "use_cache": true,
  | "vocab_size": 32000

I have modified the llama2 class, and my self.past_kv_shape is as follows: [22, 2, 1, 4, 0, 64], where 64 = hidden_size / num_attention_heads.
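For reference, a minimal sketch of how I derived that shape from config.json (the layout below is my assumption of what the llama2 class expects, not the repo's code):

import json

# Assumed KV cache layout: [layers, 2 (key/value), batch, kv_heads, past_len, head_dim].
# The past-length dimension starts at 0 at export time and grows during decoding.
with open("config.json") as f:
    cfg = json.load(f)

head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]  # 2048 // 32 = 64
past_kv_shape = [
    cfg["num_hidden_layers"],     # 22
    2,                            # key and value
    1,                            # batch
    cfg["num_key_value_heads"],   # 4 (GQA: fewer KV heads than query heads)
    0,                            # past sequence length
    head_dim,                     # 64
]
print(past_kv_shape)              # [22, 2, 1, 4, 0, 64]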

Is this modification correct? Is there anything else I need to modify, e.g. attention_mask, position_ids, etc.?

I am able to convert the model successfully with this modification, but I think the generated output is not correct. How can I verify that the model is working correctly?

@wangzhaode @v0jiuqi I need your guidance, please.

wangzhaode commented 7 months ago

You can add --export_test to verify the ONNX model. --export_test will run the ONNX model with onnxruntime and compare the outputs with torch.
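Roughly, it performs this kind of comparison for each exported part (a sketch, not the exact code in llm-export; the input names and block signature are assumptions):

import numpy as np
import onnxruntime as ort
import torch

def compare_block(onnx_path, torch_block, hidden_states, attention_mask, position_ids, past_kv):
    # Run the exported block with onnxruntime on the same inputs as the torch module.
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_out = sess.run(None, {
        "inputs_embeds": hidden_states.numpy(),    # input names are assumptions;
        "attention_mask": attention_mask.numpy(),  # check sess.get_inputs()
        "position_ids": position_ids.numpy(),
        "past_key_values": past_kv.numpy(),
    })[0]
    with torch.no_grad():
        torch_out = torch_block(hidden_states, attention_mask, position_ids, past_kv)[0]
    max_err = np.abs(onnx_out - torch_out.numpy()).max()
    print("max abs diff:", max_err)
    return max_err < 1e-3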

Nick-infinity commented 7 months ago

You can add --export_test to verify the ONNX model. --export_test will run the ONNX model with onnxruntime and compare the outputs with torch.

Yes, thank you for your response. I tried --export_test and I can see "onnx test SUCCESS" for the lm, embedding, and all 22 blocks of my model.

I added it to mnn-llm and the output seems to be correct.

Now I have another model (tinyllama_multilingual) which has an extended vocabulary, i.e. it has the same number of blocks and the same attention configuration as TinyLlama 1.1B, but the vocab size is 160984 instead of 32000.

I converted the model in the same fashion as TinyLlama 1.1B, but when I run this model in MNN it generates garbage results. I checked tokenizer.cpp and fed the tokens manually as well, but the output token ids correspond to incorrect string values.

How should I debug this and what should I check? The only difference from TinyLlama 1.1B is the larger vocabulary size; everything else is the same.

This means that the lm head has a convolution of size 160984 x 2048 instead of 32000 x 2048.

I suspect that some matmul or convolution layer is producing incorrect results where the embedding is involved, or maybe I passed a wrong parameter in key_value_shape when converting the model and GQA is not being processed correctly.

How should I debug this?
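One first check I can run (a sketch; the onnx/lm.onnx path is an assumption) is to confirm that the lm head weight kept its full 160984 x 2048 shape in the exported ONNX:

import onnx

# List the initializers of the exported lm head and check the weight shape.
# The file name "onnx/lm.onnx" is an assumption; use the actual export path.
m = onnx.load("onnx/lm.onnx")
for init in m.graph.initializer:
    print(init.name, list(init.dims))   # expect a 160984 x 2048 weight tensor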

Nick-infinity commented 7 months ago

onnx test SUCCESS
Don't has bizCode, use MNNTest for default
Start to Convert Other Model Format To MNN Model..., target version: 2.8
[15:30:40] :46: ONNX Model ir version: 8
[15:30:40] :47: ONNX Model opset version: 15
Start to Optimize the MNN Net...
Nick-infinity commented 7 months ago

This is the config.json file of my custom model with the extended vocab:

{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "auto_map": {
    "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM"
  },
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.36.2",
  "use_cache": false,
  "vocab_size": 160984
}
Nick-infinity commented 7 months ago

Hey @wangzhaode, can you give me a clue about what I should fix to get my GQA model with the large vocab size working?

wangzhaode commented 7 months ago

Can you convert the large vocab size lm.onnx to lm.mnn?

Nick-infinity commented 7 months ago

Can you convert the large vocab size lm.onnx to lm.mnn?

Yes, I can convert the lm.onnx to lm.mnn.

lm.mnn: (screenshot)

lm.onnx: (screenshot)
Nick-infinity commented 7 months ago

I dumped my block_0.mnn to block_0.json. I can see that the bias values are all 0.0 in my JSON, while the bias values are non-zero in the Qwen-1.8B block_0 JSON file. Does this give any hint?

Another doubt: the model is int4 quantized, but I can see the "buffer" data is clipped between -127 and 127, i.e. int8. I have used asymmetric quantization though, which might produce a non-zero bias.

This is the case with the Qwen model too. I am not sure if this is just a JSON printing bug, or if MNN processes the int8 quantized values again while loading the model and creates int4 weights in memory.

Nick-infinity commented 7 months ago

I think I might have found the issue. @wangzhaode, all block_#.mnn files should be the same size, right? My block_0.mnn, block_1.mnn, and block_21.mnn are all different sizes. When I checked the length of "buffer" in the JSON using MNNDump2Json, the first 2 convolutions of each block have a different number of elements in "buffer".

Is this behaviour correct?

Nick-infinity commented 7 months ago

I cross-checked the llama2-7b model. I was expecting the block sizes to be the same, as each block is exactly the same module with the same dimensions for each layer. But there is a difference in the buffer length and type of the convolutions present in block_0.mnn and block_1.mnn of the llama2-7b model.

The block 0 formatted JSON data is:

Name: /block/self_attn/q_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 5889803
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 2

Name: /block/self_attn/k_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 6950749
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 2

Name: /block/self_attn/v_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8388630
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 1

Name: /block/self_attn/o_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8137910
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 2

Name: /block/mlp/gate_proj/MatMul_output_0__matmul_converted
Len of Bias:  11008
Len of alpha: 22016
Len of Buffer: 22544406
Input Count: 4096
Output Count: 11008
Read Type: 11008
Type: 1

The block 1 formatted JSON data is:

Name: /block/self_attn/q_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8388630
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 1

Name: /block/self_attn/k_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8388630
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 1

Name: /block/self_attn/v_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8388630
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 1

Name: /block/self_attn/o_proj/MatMul_output_0__matmul_converted
Len of Bias:  4096
Len of alpha: 8192
Len of Buffer: 8388630
Input Count: 4096
Output Count: 4096
Read Type: 4096
Type: 1

Name: /block/mlp/gate_proj/MatMul_output_0__matmul_converted
Len of Bias:  11008
Len of alpha: 22016
Len of Buffer: 22544406
Input Count: 4096
Output Count: 11008
Read Type: 11008
Type: 1
Nick-infinity commented 7 months ago

Okay, so Type 1 and Type 2 quantization for convolutions are based on the sparse vs. compressed blob size of the quantized weights, and I am guessing MNN has both implementations for convolutions.

This means the block models can have different sizes depending on whether the sparse version of the quantized weights is smaller or the compressed version is smaller.

I tried manually forcing sparse vs. compressed convolutions during the block model conversion, but it didn't help much with accuracy.

Another interesting fact: for quantization, MNN does not use block-wise quantization like llama.cpp, where the block size is 32. It uses the width of the weights in the convolution kernel as the block size, e.g. if you have a matmul with a 2048 x 256 weight, the block size will be 256. This will definitely be less accurate than the smaller block size of 32 used in group quants.
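To see the effect, here is an illustrative sketch (not MNN's code) comparing per-channel quantization (block = full row width) against block-32 asymmetric 4-bit quantization of a 2048 x 256 weight:

import numpy as np

def quant_dequant(w, block):
    # Asymmetric 4-bit quantize/dequantize with one (min, scale) pair per block of each row.
    out_ch, in_ch = w.shape
    w = w.reshape(out_ch, in_ch // block, block)
    lo = w.min(axis=-1, keepdims=True)
    hi = w.max(axis=-1, keepdims=True)
    scale = (hi - lo) / 15.0                      # 16 levels for 4 bits
    q = np.clip(np.round((w - lo) / scale), 0, 15)
    return (q * scale + lo).reshape(out_ch, in_ch)

rng = np.random.default_rng(0)
w = rng.standard_normal((2048, 256)).astype(np.float32)

err_channel = np.abs(quant_dequant(w, 256) - w).mean()  # block = row width, like channel-wise
err_block32 = np.abs(quant_dequant(w, 32) - w).mean()   # llama.cpp-style group of 32
print(err_channel, err_block32)                         # the block-32 error is smaller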

wangzhaode commented 7 months ago

This is the case with the Qwen model too. I am not sure if this is just a JSON printing bug, or if MNN processes the int8 quantized values again while loading the model and creates int4 weights in memory.

That is not a problem; no matter how many bits are used for quantization, the weights are displayed in int8 form in the JSON.

wangzhaode commented 7 months ago

Okay, so Type 1 and Type 2 quantization for convolutions are based on the sparse vs. compressed blob size of the quantized weights, and I am guessing MNN has both implementations for convolutions.

This means the block models can have different sizes depending on whether the sparse version of the quantized weights is smaller or the compressed version is smaller.

I tried manually forcing sparse vs. compressed convolutions during the block model conversion, but it didn't help much with accuracy.

Another interesting fact: for quantization, MNN does not use block-wise quantization like llama.cpp, where the block size is 32. It uses the width of the weights in the convolution kernel as the block size, e.g. if you have a matmul with a 2048 x 256 weight, the block size will be 256. This will definitely be less accurate than the smaller block size of 32 used in group quants.

Yes! Type 2 means MNN is using sparse weights. MNN's quantization method is channel-wise quantization, which is less accurate than block-wise quantization. Maybe we will support block quantization in the next version.

wangzhaode commented 7 months ago

Can you give me the tinyllama_multilingual model URL?

I'll test and debug it for you.

Nick-infinity commented 7 months ago

I see that you have added support for TinyLlama conversion. This is exactly the way I did it and got it working. My multilingual model is being trained right now; I will try to share it with you once it's ready.

My model is exactly the same as TinyLlama 1.1B but with a vocab size of 160984 instead of 32000. I got TinyLlama 1.1B working correctly with my changes, as I mentioned.

wangzhaode commented 7 months ago

I see that you have added support for TinyLlama conversion. This is exactly the way I did it and got it working. My multilingual model is being trained right now; I will try to share it with you once it's ready.

My model is exactly the same as TinyLlama 1.1B but with a vocab size of 160984 instead of 32000. I got TinyLlama 1.1B working correctly with my changes, as I mentioned.

OK!👌

A vocab size that is too large will cause an overflow bug when converting ONNX to MNN. This bug has been fixed internally and will be made public on GitHub.

Nick-infinity commented 7 months ago

I see that you have added support for TinyLlama conversion. This is exactly the way I did it and got it working. My multilingual model is being trained right now; I will try to share it with you once it's ready. My model is exactly the same as TinyLlama 1.1B but with a vocab size of 160984 instead of 32000. I got TinyLlama 1.1B working correctly with my changes, as I mentioned.

OK!👌

A vocab size that is too large will cause an overflow bug when converting ONNX to MNN. This bug has been fixed internally and will be made public on GitHub.

Thank you very much for your support and for sharing this information. I also suspected an overflow due to the large vocab size. Just one doubt: the Qwen model has a vocab size of 151,936, which is comparable to my 160,984. If possible, can you let me know what size counts as "too large" and causes the overflow? Also, when can we see the change on the public GitHub?

wangzhaode commented 7 months ago

I see that you have added support for TinyLlama conversion. This is exactly the way I did it and got it working. My multilingual model is being trained right now; I will try to share it with you once it's ready. My model is exactly the same as TinyLlama 1.1B but with a vocab size of 160984 instead of 32000. I got TinyLlama 1.1B working correctly with my changes, as I mentioned.

OK!👌 A vocab size that is too large will cause an overflow bug when converting ONNX to MNN. This bug has been fixed internally and will be made public on GitHub.

Thank you very much for your support and for sharing this information. I also suspected an overflow due to the large vocab size. Just one doubt: the Qwen model has a vocab size of 151,936, which is comparable to my 160,984. If possible, can you let me know what size counts as "too large" and causes the overflow? Also, when can we see the change on the public GitHub?

vocab_size * hidden_size * sizeof(float) > INT_MAX will cause the bug. The Qwen export uses a branch with a patch that is not suitable for publishing.

Nick-infinity commented 7 months ago

Understood, thanks. I am happy that this bug was caught and a fix is in progress. I am just guessing that the size of the data structure storing the embedding weights is clipped to INT_MAX somewhere in MNN. 😅

Although 160984 x 2048 x 4 = 1,318,780,928, which is less than INT_MAX, i.e. 2,147,483,647.
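As a quick sanity check of that formula in plain Python (sizes taken from my config.json above):

INT_MAX = 2**31 - 1                    # 2,147,483,647
lm_bytes = 160984 * 2048 * 4           # vocab_size * hidden_size * sizeof(float)
print(lm_bytes, lm_bytes > INT_MAX)    # 1318780928 False, i.e. below INT_MAX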

Nick-infinity commented 7 months ago

@wangzhaode I tried finding the overflow in the MNN model export as per your advice. I was not able to locate an overflow, as vocab_size * hidden_size * sizeof(float) < INT_MAX in my case. I tried checking the convolution layer for the lm head and the export optimizer's MatMul -> Convolution2D pass. I am building MNN from source and using the MNNConvert binary instead of the Python wheel. Can you please point me to the code which has the bug so that I can try to fix it myself before the next release of MNN? We have a project that depends on this fix. 🙌🏻

Nick-infinity commented 7 months ago

@wangzhaode any help or guidance regarding this?

@wangzhaode I tried finding the overflow in the MNN model export as per your advice. I was not able to locate an overflow, as vocab_size * hidden_size * sizeof(float) < INT_MAX in my case. I tried checking the convolution layer for the lm head and the export optimizer's MatMul -> Convolution2D pass. I am building MNN from source and using the MNNConvert binary instead of the Python wheel. Can you please point me to the code which has the bug so that I can try to fix it myself before the next release of MNN? We have a project that depends on this fix. 🙌🏻

Nick-infinity commented 7 months ago

I printed all the inputs to the ArgMax layer of my lm head. The maximum value is always at index 1 out of the 160984-entry vocab. This is very weird, and I think the lm convolution weights are doing something wrong. It might be caused by the same issue you mentioned, i.e. the overflow happening during the ONNX -> MNN conversion. I have not been able to locate it yet.
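To narrow it down, this is the kind of check I plan to run next (a sketch; the model path, lm.onnx location, input name, and expected input shape are assumptions):

import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModelForCausalLM

# Compare the exported lm head against the original torch lm_head on a random hidden state.
model = AutoModelForCausalLM.from_pretrained("path/to/tinyllama_multilingual",
                                             torch_dtype=torch.float32)
hidden = torch.randn(1, 2048)                      # fake final hidden state

with torch.no_grad():
    ref_token = model.lm_head(hidden).argmax(-1).item()   # reference token id

sess = ort.InferenceSession("onnx/lm.onnx", providers=["CPUExecutionProvider"])
inp = sess.get_inputs()[0]
print("lm.onnx input:", inp.name, inp.shape)       # adjust the feed below to match
onnx_out = sess.run(None, {inp.name: hidden.numpy()})[0]
print("torch argmax:", ref_token, "onnx output:", np.asarray(onnx_out).ravel()[:8])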