Kipok / NeMo-Skills

A pipeline to improve skills of large language models
https://kipok.github.io/NeMo-Skills/
Apache License 2.0

FIX nemo to hf conversion #228

Closed gwarmstrong closed 2 days ago

gwarmstrong commented 6 days ago

Fixes NeMo to HF conversion, which was failing in some hard-to-predict cases due to insufficient configuration.

Kipok commented 5 days ago

@gwarmstrong can you share the error message you're getting that needs this fix? That's unexpected, since we are converting models all the time and have never encountered this.

gwarmstrong commented 3 days ago

@Kipok

> @gwarmstrong can you share the error message you're getting that needs this fix? That's unexpected, since we are converting models all the time and have never encountered this.

Yes. When trying to convert the tiny Llama model from NeMo to HF, I get a number of errors like the following:

nemo-run/0 [rank0]: Traceback (most recent call last):
nemo-run/0 [rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
nemo-run/0 [rank0]:     return _run_code(code, main_globals, None,
nemo-run/0 [rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
nemo-run/0 [rank0]:     exec(code, run_globals)
nemo-run/0 [rank0]:   File "/nemo_run/code/nemo_skills/conversion/nemo_to_hf_llama.py", line 246, in <module>
nemo-run/0 [rank0]:     convert(
nemo-run/0 [rank0]:   File "/nemo_run/code/nemo_skills/conversion/nemo_to_hf_llama.py", line 236, in convert
nemo-run/0 [rank0]:     model.load_state_dict(checkpoint)
nemo-run/0 [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
nemo-run/0 [rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
nemo-run/0 [rank0]: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
nemo-run/0 [rank0]:     size mismatch for model.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([4096, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([1024, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([1024, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.0.self_attn.o_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 4096]).
nemo-run/0 [rank0]:     size mismatch for model.layers.1.self_attn.q_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([4096, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([1024, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([1024, 64]).
nemo-run/0 [rank0]:     size mismatch for model.layers.1.self_attn.o_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 4096]).

The current conversion script makes assumptions about the number of heads and head_dim that do not hold when the model is initialized with non-default values for them. This is not an issue for the conversions that are typically tested or run, but the script is not generic enough for all cases.
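
To illustrate the shape arithmetic behind the traceback, here is a minimal sketch (not the conversion script's actual code). It assumes the usual Llama-style layout, where q_proj has num_heads * head_dim output rows and k_proj/v_proj have num_kv_heads * head_dim. The function name attn_proj_shapes and the values 32 heads, 8 KV heads, and head_dim 128 are hypothetical: they are one choice that happens to reproduce the shapes in the error, and the real configuration may differ.

```python
def attn_proj_shapes(hidden_size, num_heads, num_kv_heads, head_dim=None):
    """Return (out_features, in_features) for Llama-style attention projections."""
    # When head_dim is not given, derive it from hidden_size.
    if head_dim is None:
        head_dim = hidden_size // num_heads
    q_out = num_heads * head_dim
    kv_out = num_kv_heads * head_dim
    return {
        "q_proj": (q_out, hidden_size),
        "k_proj": (kv_out, hidden_size),
        "v_proj": (kv_out, hidden_size),
        "o_proj": (hidden_size, q_out),
    }


# Hypothetical tiny model: hidden_size 64, so the derived head_dim is 64 // 32 = 2.
checkpoint = attn_proj_shapes(hidden_size=64, num_heads=32, num_kv_heads=8)

# Same model, but built with an assumed head_dim of 128 instead of the derived one.
model = attn_proj_shapes(hidden_size=64, num_heads=32, num_kv_heads=8, head_dim=128)

for name in checkpoint:
    if checkpoint[name] != model[name]:
        print(f"size mismatch for {name}: checkpoint {checkpoint[name]}, model {model[name]}")
```

Under these assumed values, the checkpoint shapes come out as (64, 64) and (16, 64) while the model expects (4096, 64) and (1024, 64), which matches the pattern in the traceback. A head_dim that is not tied to the tiny model's hidden size is enough to explain all four mismatches.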

gwarmstrong commented 2 days ago

@shtoshni any chance you can test locally? I don't have enough memory on my system to complete the test_hf_nemo_conversion step needed to verify test_nemo_hf_conversion.

shtoshni commented 2 days ago

> @shtoshni any chance you can test locally? I don't have enough memory on my system to complete the test_hf_nemo_conversion step needed to verify test_nemo_hf_conversion.

Both tests passed. We're good to go.