pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Gemma2B missing lm head weight? #1062

Closed hmosousa closed 4 days ago

hmosousa commented 2 months ago

Hi there,

I was trying to import the weights from gemma_2b to transformers GemmaForCausalLM but it seems to be missing the lm_head.weight key in the state dict. Code below:

from torchtune.models.convert_weights import tune_to_hf
from torchtune.models.gemma import gemma_2b
from transformers import GemmaForCausalLM

tune_gemma = gemma_2b()
tune_state_dict = tune_gemma.state_dict()
print(len(tune_state_dict))  # 164

hf_gemma = GemmaForCausalLM.from_pretrained("google/gemma-2b")
hf_state_dict = hf_gemma.state_dict()
print(len(hf_state_dict))  # 165

tune_to_hf_state_dict = tune_to_hf(tune_state_dict, num_heads=8, num_kv_heads=1, dim=2048)
hf_gemma.load_state_dict(tune_to_hf_state_dict)  # Missing key(s) in state_dict: "lm_head.weight". 

environment:

transformers==4.41.2
torchtune @ git+https://github.com/pytorch/torchtune@9f31a882a98609452e4b357073bd251da7d341cd

Did I miss something?

SalmanMohammadi commented 2 months ago

Hey @hmosousa. Thanks so much for raising this. I've managed to replicate the issue, and I think we need a little more logic for Gemma HF checkpointing in torchtune. I've fixed it on my end and I'll put a PR up soon.
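
For context, Gemma shares (ties) its output projection with the token embedding matrix, which would explain why the converted state dict has one key fewer than the Hugging Face one. Below is a minimal sketch of the kind of extra conversion step this could involve. It assumes the HF key names `model.embed_tokens.weight` and `lm_head.weight`; the helper name `add_tied_lm_head` is hypothetical and not part of torchtune, and the stand-in values replace real tensors so the snippet runs without torch.

```python
from typing import Any, Dict

EMBED_KEY = "model.embed_tokens.weight"  # HF name of the token embedding
LM_HEAD_KEY = "lm_head.weight"           # HF name of the output projection


def add_tied_lm_head(hf_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the state dict whose lm_head.weight is the embedding.

    Hypothetical helper for illustration only. Values are normally
    torch.Tensor objects, but nothing here depends on that.
    """
    patched = dict(hf_state_dict)  # shallow copy: values are shared, not cloned
    if LM_HEAD_KEY not in patched:
        if EMBED_KEY not in patched:
            raise KeyError(f"cannot tie {LM_HEAD_KEY}: {EMBED_KEY} is missing")
        patched[LM_HEAD_KEY] = patched[EMBED_KEY]
    return patched


# Stand-in values so the sketch runs without torch installed.
converted = {EMBED_KEY: [[0.1, 0.2], [0.3, 0.4]], "model.norm.weight": [1.0, 1.0]}
patched = add_tied_lm_head(converted)
print(sorted(patched))
```

With real tensors, the patched dict could then be passed to `hf_gemma.load_state_dict(...)` in place of the raw output of `tune_to_hf`. Sharing the same object rather than cloning mirrors the tying, at the cost of the two keys aliasing one tensor.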

SalmanMohammadi commented 2 months ago

Hey @hmosousa. This should now be available in torchtune nightly. Let me know if you have any further issues, or whether we can close this off.

hmosousa commented 2 months ago

Looks good to me! Thanks for the help

joecummings commented 1 month ago

Reopening this because we reverted the checkpointing logic of #1169 in #1168.

The reason is that this one-off checkpointing logic left us unable to properly save Gemma models in E2E runs (see #1122). Separately, we are working on better support for from_pretrained loading.

felipemello1 commented 1 month ago

@joecummings

I believe that gemma is missing optimizer_in_bwd and compile parameters. Running:

tune run full_finetune_single_device --config gemma/7B_full max_steps_per_epoch=60 optimizer_in_bwd=True compile=True

yields the error

  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/recipes/full_finetune_single_device.py", line 524, in <module>
    sys.exit(recipe_main())
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/recipes/full_finetune_single_device.py", line 519, in recipe_main
    recipe.train()
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/recipes/full_finetune_single_device.py", line 501, in train
    self.save_checkpoint(epoch=curr_epoch)
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/recipes/full_finetune_single_device.py", line 394, in save_checkpoint
    self._checkpointer.save_checkpoint(
  File "/home/felipemello/.conda/envs/test_release/lib/python3.10/site-packages/torchtune/utils/_checkpointing/_checkpointer.py", line 508, in save_checkpoint
    cpt_idx = self._weight_map[key]
KeyError: 'lm_head.weight'
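
The last frame of the traceback is a plain dictionary lookup. Below is a minimal sketch of how such a lookup can fail. It assumes the checkpointer records at load time which file each key came from and reuses that map at save time. The function names, file names, and on-disk layout are hypothetical and not torchtune's implementation.

```python
from typing import Dict, List


def build_weight_map(files: Dict[str, List[str]]) -> Dict[str, str]:
    """Map every parameter key to the checkpoint file it was loaded from."""
    return {key: fname for fname, keys in files.items() for key in keys}


def split_by_file(keys: List[str], weight_map: Dict[str, str]) -> Dict[str, List[str]]:
    """Group keys by destination file, as a save step might do."""
    out: Dict[str, List[str]] = {}
    for key in keys:
        fname = weight_map[key]  # raises KeyError for keys never seen at load time
        out.setdefault(fname, []).append(key)
    return out


# Hypothetical layout: the tied lm_head.weight is not stored in any file.
loaded = {"shard-1": ["model.embed_tokens.weight"], "shard-2": ["model.norm.weight"]}
weight_map = build_weight_map(loaded)

# A conversion step that adds lm_head.weight yields a key the map does not know.
keys_to_save = ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]
try:
    split_by_file(keys_to_save, weight_map)
except KeyError as err:
    failed_key = err.args[0]
    print(f"KeyError: {failed_key!r}")
```

Under this assumption, the save path would need either to drop the extra key before the lookup or to assign it a destination file explicitly.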

ebsmothers commented 1 month ago

@felipemello1 this is because you're using the single-device recipe with the distributed config, and the two are incompatible. optimizer_in_bwd and compile are both supported only on a single device because they are difficult to compose with FSDP (though FSDP2 should at least fix this for compile). The bigger point here is that you were able to run an incompatible recipe/config pair without any complaint. I thought we raised an error in this case, but apparently we don't, so we should change that.
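
Below is a rough sketch of the kind of guard this suggests, in which a config declares the recipes it was written for and the launcher refuses any other pairing. Every name here (`SUPPORTED_RECIPES`, `check_recipe_config`, and the table contents) is hypothetical; torchtune does not necessarily track compatibility this way.

```python
from typing import Dict, Set

# Illustrative table only: which recipes each config is meant to be run with.
SUPPORTED_RECIPES: Dict[str, Set[str]] = {
    "gemma/7B_full": {"full_finetune_distributed"},
}


def check_recipe_config(recipe: str, config: str) -> None:
    """Raise early if a config is paired with a recipe it does not list."""
    allowed = SUPPORTED_RECIPES.get(config)
    if allowed is None:
        return  # unknown (e.g. user-written) config: nothing to validate
    if recipe not in allowed:
        raise ValueError(
            f"config {config!r} is not compatible with recipe {recipe!r}; "
            f"expected one of {sorted(allowed)}"
        )


check_recipe_config("full_finetune_distributed", "gemma/7B_full")  # passes silently
try:
    check_recipe_config("full_finetune_single_device", "gemma/7B_full")
except ValueError as err:
    message = str(err)
    print(message)
```

Failing before any training starts would surface the mismatch up front rather than at the first checkpoint save.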

RdoubleA commented 4 days ago

original error was resolved, closing