Chillee opened this issue 4 months ago
`convert_hf_checkpoint.py` fails with the following error for gemma-7b:

```
Model config {'block_size': 2048, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'dim': 3072, 'intermediate_size': 24576, 'n_local_heads': 16, 'head_dim': 192, 'rope_base': 10000, 'norm_eps': 1e-05}
Traceback (most recent call last):
  File "scripts/convert_hf_checkpoint.py", line 111, in <module>
    convert_hf_checkpoint(
  File "......../site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "scripts/convert_hf_checkpoint.py", line 91, in convert_hf_checkpoint
    q = permute(q, config.n_head)
  File "scripts/convert_hf_checkpoint.py", line 62, in permute
    w.view(n_head, 2, config.head_dim // 2, dim)
RuntimeError: shape '[16, 2, 96, 3072]' is invalid for input of size 12582912
```
The issue is that `ModelArgs.head_dim` is computed as 192, while the HF config dictates it to be 256. I tried forcing it to 256, but then inference fails with the following error for each layer:

```
size mismatch for layers.0.attention.wo.weight: copying a param with shape torch.Size([3072, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 3072]).
```
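For reference, both shapes in the mismatch fall straight out of the config numbers (a quick sanity check, assuming the checkpoint stores `wo` as `[dim, n_head * head_dim]` — the exact layout in the checkpoint is an assumption here):

```python
# gemma-7b numbers from config.json
n_head, head_dim, dim = 16, 256, 3072

# Shape of wo in the HF checkpoint (assumed [dim, n_head * head_dim]):
ckpt_shape = (dim, n_head * head_dim)
print(ckpt_shape)  # (3072, 4096)

# Shape the model builds when it assumes n_head * head_dim == dim:
model_shape = (dim, dim)
print(model_shape)  # (3072, 3072)
```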
Same error as @shaahji saw above. I tried with the config `{'dim': 3072, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'n_local_heads': 16, 'intermediate_size': 24576}`, following gemma-7b/config.json:
```json
{
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 24576,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}
```
It seems that if we force `'head_dim': 256`, then `'dim'` gets bumped to 4096. Also, I have no idea where the number 12582912 comes from. It's like a puzzle to figure out how those numbers are mapped and determined. @Chillee, could you elaborate?
I added support for gemma-7b. The main non-trivial component here was that `head_dim * n_heads != dim`, so some parts of the model definition needed to be patched.
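To illustrate the kind of patch involved, here is a minimal shape-level sketch (numpy stand-in for the torch code; the real `permute` in `convert_hf_checkpoint.py` may differ in details): the rotary q/k permute has to produce `n_head * head_dim` output rows instead of `dim`, since those are no longer equal for gemma-7b.

```python
import numpy as np

# Hypothetical generalization of the q/k permute from convert_hf_checkpoint.py,
# rewritten so it no longer assumes n_head * head_dim == dim.
def permute(w, n_head, head_dim, dim):
    return (
        w.reshape(n_head, 2, head_dim // 2, dim)
         .transpose(0, 2, 1, 3)           # numpy equivalent of torch's .transpose(1, 2)
         .reshape(n_head * head_dim, dim)  # was reshape(dim, dim) when the two coincided
    )

# gemma-7b: dim=3072, n_head=16, head_dim=256, so the q weight is [4096, 3072]
w = np.zeros((16 * 256, 3072), dtype=np.float16)
out = permute(w, n_head=16, head_dim=256, dim=3072)
print(out.shape)  # (4096, 3072)
```

With the old assumption (`head_dim = dim // n_head = 192`), the first `reshape` is exactly the `view` that raises the RuntimeError in the traceback above.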
I'm getting 83 tok/s for fp16.
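To answer the earlier question about where 12582912 comes from: it factors cleanly out of the config numbers. A quick check (assuming the HF q-projection weight has shape `[n_head * head_dim, dim]`, which is an assumption about the checkpoint layout):

```python
n_head, dim = 16, 3072
hf_head_dim = 256                  # head_dim from gemma-7b config.json
derived_head_dim = dim // n_head   # 192, what ModelArgs was computing

# Elements in the checkpoint's q weight (assumed [n_head * head_dim, dim]):
q_numel = n_head * hf_head_dim * dim
print(q_numel)  # 12582912

# Elements the failing view [16, 2, 96, 3072] would need:
view_numel = n_head * 2 * (derived_head_dim // 2) * dim
print(view_numel)  # 9437184 -- hence the RuntimeError
```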
cc: @guangy10 @shaahji
Can you extend support for gemma-7b as well?