pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

[example] Added gemma support #115

Open · Chillee opened this pull request 4 months ago

shaahji commented 4 months ago

Can you extend support for gemma-7b as well?

shaahji commented 4 months ago

convert_hf_checkpoint.py fails with the following error for gemma-7b:

Model config {'block_size': 2048, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'dim': 3072, 'intermediate_size': 24576, 'n_local_heads': 16, 'head_dim': 192, 'rope_base': 10000, 'norm_eps': 1e-05}
Traceback (most recent call last):
  File "scripts/convert_hf_checkpoint.py", line 111, in <module>
    convert_hf_checkpoint(
  File "......../site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "scripts/convert_hf_checkpoint.py", line 91, in convert_hf_checkpoint
    q = permute(q, config.n_head)
  File "scripts/convert_hf_checkpoint.py", line 62, in permute
    w.view(n_head, 2, config.head_dim // 2, dim)
RuntimeError: shape '[16, 2, 96, 3072]' is invalid for input of size 12582912

The issue is that ModelArgs.head_dim is computed as 192, but the HF config dictates it should be 256. I tried forcing it, but then it fails during inference with the following error for each layer:

size mismatch for layers.0.attention.wo.weight: copying a param with shape torch.Size([3072, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 3072]).
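(A note reconstructing the arithmetic from the printed ModelArgs and the traceback above, assuming gpt-fast derives head_dim as dim // n_head, which the 192 in the log suggests; this is an illustration, not code quoted from the repo:)

```python
# Sketch of the shape mismatch for gemma-7b, assuming head_dim = dim // n_head
# inside ModelArgs (consistent with the printed config, not quoted from source).
dim, n_head = 3072, 16
derived_head_dim = dim // n_head   # 192 -- the value gpt-fast prints above
gemma_head_dim = 256               # the value gemma-7b/config.json specifies

# The HF checkpoint stores q_proj.weight as [n_head * head_dim, hidden_size]:
checkpoint_numel = n_head * gemma_head_dim * dim            # 12_582_912

# permute() then views it as [n_head, 2, head_dim // 2, dim] using the
# *derived* head_dim, which accounts for fewer elements:
view_numel = n_head * 2 * (derived_head_dim // 2) * dim     # 9_437_184

print(checkpoint_numel, view_numel)  # 12582912 9437184 -> RuntimeError above
```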
guangy10 commented 4 months ago

Same error as @shaahji saw above. I tried with the config {'dim': 3072, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'n_local_heads': 16, 'intermediate_size': 24576}, based on gemma-7b/config.json:

{
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 24576,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": true,
  "vocab_size": 256000
}

It seems like if we force 'head_dim': 256, then 'dim' gets bumped to 4096. Also, I have no idea where the number "12582912" comes from; it's a puzzle to figure out how those numbers are mapped and determined. @Chillee could you elaborate?
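(For the specific numbers being asked about, a reconstruction from the shapes in the traceback and config.json above rather than an answer from the PR itself:)

```python
# 12582912 is the element count of gemma-7b's q_proj.weight in the HF checkpoint,
# whose shape is [num_attention_heads * head_dim, hidden_size]:
print(16 * 256 * 3072)  # 12582912

# 4096 is n_head * head_dim, which is why forcing head_dim=256 widens the
# attention projections (e.g. wo.weight becomes [3072, 4096]) beyond dim=3072:
print(16 * 256)         # 4096
```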

Chillee commented 3 months ago

I added support for gemma-7b. The main non-trivial component here was that head_dim * n_heads != dim, so some parts of the model definition needed to be patched.
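(A rough sketch of what head_dim * n_heads != dim implies for the attention module, consistent with the wo.weight shapes in the earlier error; the class name and defaults below are illustrative, not the exact diff in this PR:)

```python
import torch.nn as nn

# Illustrative only: when n_head * head_dim != dim (gemma-7b: 16 * 256 = 4096
# vs. dim = 3072), the projections can no longer be sized by dim alone.
class DecoupledHeadAttention(nn.Module):
    def __init__(self, dim=3072, n_head=16, n_local_heads=16, head_dim=256):
        super().__init__()
        total_head_dim = (n_head + 2 * n_local_heads) * head_dim
        self.wqkv = nn.Linear(dim, total_head_dim, bias=False)  # fused q, k, v
        # The output projection maps n_head * head_dim (4096) back to dim (3072),
        # matching the [3072, 4096] wo.weight shape reported above.
        self.wo = nn.Linear(n_head * head_dim, dim, bias=False)

attn = DecoupledHeadAttention()
print(attn.wo.weight.shape)  # torch.Size([3072, 4096])
```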

I'm getting 83 tok/s for fp16

cc: @guangy10 @shaahji