Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Gemma 2B weights seem to have changed #1665

Open rasbt opened 2 months ago

rasbt commented 2 months ago

Bug description

It seems that Google updated the Gemma v1 2B weights. Something to look into:

⚡ main ~/litgpt litgpt chat checkpoints/google/gemma-2b
{'access_token': None,
 'checkpoint_dir': PosixPath('checkpoints/google/gemma-2b'),
 'compile': False,
 'max_new_tokens': 50,
 'multiline': False,
 'precision': None,
 'quantize': None,
 'temperature': 0.8,
 'top_k': 50,
 'top_p': 1.0}
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/litgpt/litgpt/chat/base.py", line 260, in main
    load_checkpoint(fabric, model, checkpoint_path)
  File "/teamspace/studios/this_studio/litgpt/litgpt/utils.py", line 362, in load_checkpoint
    model.load_state_dict(state_dict, strict=strict)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT:
        Missing key(s) in state_dict: "transformer.h.0.norm_2.weight", "transformer.h.1.norm_2.weight", "transformer.h.2.norm_2.weight", "transformer.h.3.norm_2.weight", "transformer.h.4.norm_2.weight", "transformer.h.5.norm_2.weight", "transformer.h.6.norm_2.weight", "transformer.h.7.norm_2.weight", "transformer.h.8.norm_2.weight", "transformer.h.9.norm_2.weight", "transformer.h.10.norm_2.weight", "transformer.h.11.norm_2.weight", "transformer.h.12.norm_2.weight", "transformer.h.13.norm_2.weight", "transformer.h.14.norm_2.weight", "transformer.h.15.norm_2.weight", "transformer.h.16.norm_2.weight", "transformer.h.17.norm_2.weight". 
        Unexpected key(s) in state_dict: "transformer.h.0.post_attention_norm.weight", "transformer.h.1.post_attention_norm.weight", "transformer.h.2.post_attention_norm.weight", "transformer.h.3.post_attention_norm.weight", "transformer.h.4.post_attention_norm.weight", "transformer.h.5.post_attention_norm.weight", "transformer.h.6.post_attention_norm.weight", "transformer.h.7.post_attention_norm.weight", "transformer.h.8.post_attention_norm.weight", "transformer.h.9.post_attention_norm.weight", "transformer.h.10.post_attention_norm.weight", "transformer.h.11.post_attention_norm.weight", "transformer.h.12.post_attention_norm.weight", "transformer.h.13.post_attention_norm.weight", "transformer.h.14.post_attention_norm.weight", "transformer.h.15.post_attention_norm.weight", "transformer.h.16.post_attention_norm.weight", "transformer.h.17.post_attention_norm.weight". 

We can either fix this or remove Gemma v1 support. Since Gemma 2 is out, I'm not sure anyone still cares about Gemma 1. What do you think, @Andrei-Aksionov?
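For anyone debugging this, here is a rough diagnostic sketch that compares the keys litgpt's `GPT` module expects against the keys actually present in the converted checkpoint. Paths are taken from the command above; using `Config.from_file` on the saved `model_config.yaml` and the meta-device trick are my assumptions about the simplest way to get the expected key names:

```python
import torch
from litgpt import GPT, Config

checkpoint_dir = "checkpoints/google/gemma-2b"

# Use the config that was saved alongside the converted checkpoint.
config = Config.from_file(f"{checkpoint_dir}/model_config.yaml")

# Instantiate on the meta device so no real memory is allocated;
# we only need the parameter names, not the weights.
with torch.device("meta"):
    model = GPT(config)

expected = set(model.state_dict().keys())
actual = set(torch.load(f"{checkpoint_dir}/lit_model.pth", map_location="cpu").keys())

print("missing:   ", sorted(expected - actual))
print("unexpected:", sorted(actual - expected))
```

If the two diffs come back as a clean one-to-one rename (norm_2 vs. post_attention_norm, as in the traceback), that would point at the conversion step rather than missing tensors.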

What operating system are you using?

Unknown

LitGPT Version

Andrei-Aksionov commented 2 months ago

I recommend investigating it.

From a quick check, neither the HF modeling file nor the weights appear to have been updated. The error message says the model didn't get norm_2.weight but got post_attention_norm.weight instead, although the conversion script should map these names exactly: https://github.com/Lightning-AI/litgpt/blob/b0ea1772f7498e2e8cc58d2ac1640b0255ced757/litgpt/scripts/convert_hf_checkpoint.py#L150
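For reference, a quick way to list the raw key names in the downloaded HF shards without loading any tensors (a rough sketch; the shard filename is an assumption and may differ in the actual repo):

```python
from safetensors import safe_open

# Hypothetical shard name; check which .safetensors files were downloaded.
shard = "checkpoints/google/gemma-2b/model-00001-of-00002.safetensors"

with safe_open(shard, framework="pt") as f:
    for key in sorted(f.keys())[:10]:
        print(key)  # raw HF keys are expected to start with "model.layers."
```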

Well, maybe not exactly, because I see that the unexpected keys start with transformer.h (litgpt's naming), while the raw HF keys should start with model.layers.

Something is off.
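If it turns out the tensors themselves are unchanged and only the names moved, something like the following one-off rename might unblock loading while we investigate. This is a guess, not a verified fix, and not a litgpt API:

```python
import torch

# Hypothetical stopgap: rewrite the converted checkpoint in place,
# renaming the new-style keys back to the names litgpt's Gemma v1
# module expects. Back up lit_model.pth before trying this.
path = "checkpoints/google/gemma-2b/lit_model.pth"
state_dict = torch.load(path, map_location="cpu")
state_dict = {
    k.replace(".post_attention_norm.", ".norm_2."): v
    for k, v in state_dict.items()
}
torch.save(state_dict, path)
```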