young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

use streaming_train_state to convert_easylm_to_hf find opt_state.1.0.count #72

Closed xzqxnet0990 closed 1 year ago

xzqxnet0990 commented 1 year ago

I used the pretrain_llama_7b.sh script to pretrain an EasyLM model and ended up with multiple streaming_train_state and checkpoint files. I then used the convert_easylm_to_hf.py script to convert a streaming_train_state file to HF format, and hit a problem with loaded[f"transformer.h.{layer_i}.attention.wq.kernel"].

Error messages:

File "/root/miniconda3/envs/EasyLM/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/root/miniconda3/envs/EasyLM/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 299, in mlxu.run(main) File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/absl/app.py", line 308, in run _run_main(main, args) File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main sys.exit(main(argv)) File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 291, in main write_model( File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 152, in write_model loaded[f"transformer.h.{layer_i}.attention.wq.kernel"] KeyError: 'transformer.h.0.attention.wq.kernel'

The error happens in convert_easylm_to_hf.py, around lines 136~151. When I print the loaded dict, I see something weird: there are two kinds of transformer keys in a single streaming_train_state file.

One is

'params.params.lm_head.kernel': tensor([[ 0.1108, -0.1748,  1.3438, ..., -0.5273, -0.3320, -0.5391],
        [ 0.1084, -0.1680,  1.3047, ..., -0.5195, -0.3242, -0.5273],
        [-0.0574,  0.1104, -0.6836, ...,  0.1118,  0.0212,  0.2344],
        ...,
        [-0.0364, -0.1108,  0.6680, ..., -0.3320, -0.2178, -0.3848],
        [ 0.0688, -0.0488,  0.6758, ..., -0.3262, -0.2412, -0.3691],
        [ 0.1108, -0.1748,  1.3438, ..., -0.5273, -0.3320, -0.5391]], dtype=torch.float16),
'params.params.transformer.h.0.attention.wk.kernel'

The key begins with params.params.

The other is

'opt_state.1.0.mu.params.transformer.h.1.attention_norm.kernel': tensor([ 1.5354e-04, -1.2589e-04, -1.2112e-04, -7.4863e-05,  2.8491e-05, -2.6703e-05, -1.7643e-04,  2.0504e-04,
        -2.7084e-04,  6.8188e-05, -3.6478e-05,  1.8501e-04,  3.4332e-05, -2.9373e-04,  7.2956e-05, -1.2934e-05,
         2.8968e-05, -6.4373e-05, -2.1458e-05, -1.3065e-04, -3.6478e-05,  2.7061e-05,  1.0300e-04,  2.1458e-05,
        -2.2316e-04, -2.8419e-04,  1.2577e-05, -7.5340e-05,  2.7657e-04, -2.1577e-05, -1.0157e-04, -1.4687e-04,
         5.8413e-05,  1.0824e-04, -8.4877e-05, -1.0109e-04, -2.2411e-04,  8.8692e-05,  1.3161e-04,  1.4591e-04,
        -1.5378e-05, -2.8992e-04, -5.8889e-05, -1.0204e-04, -2.3842e-04, -2.4605e-04, -2.9802e-06, -1.0347e-04,
        -1.1826e-04, -9.9659e-05, -3.6621e-04, -2.4223e-04,  1.1921e-06,  8.1539e-05, -1.0347e-04,  1.6689e-06,
         2.5749e-04, -2.7466e-04, -2.6703e-04,  2.1172e-04, -1.2040e-05,  2.0218e-04,  8.6308e-05, -3.0518e-04]

The key begins with opt_state.1.0.mu.

However, I could not find any code that refers to opt_state.1.0.mu anywhere in the project. If I want to convert these streaming_train_state files, I have to change the source to something like loaded[f"params.params.transformer.h.{layer_i}.attention.wq.kernel"].
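For reference, the two key families are consistent with an ordinary Flax TrainState serialized together with its Optax optimizer state; the mu/nu entries are AdamW's first and second moments. A minimal sketch (not EasyLM code; names are illustrative) that reproduces the same key pattern:

    # Minimal sketch: flatten a Flax TrainState that uses an Optax chain with AdamW,
    # to see where keys like "opt_state.1.0.mu..." and "opt_state.1.0.count" come from.
    import jax.numpy as jnp
    import optax
    from flax import serialization
    from flax.training.train_state import TrainState
    from flax.traverse_util import flatten_dict

    params = {"transformer": {"h": {"0": {"attention": {"wq": {"kernel": jnp.zeros((4, 4))}}}}}}
    tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(1e-4))
    state = TrainState.create(apply_fn=lambda *a, **kw: None, params=params, tx=tx)

    # Serialize the whole train state and flatten it, roughly what a streaming
    # checkpointer writes to disk.
    flat = flatten_dict(serialization.to_state_dict(state), sep=".")
    for key in flat:
        print(key)
    # Printed keys include:
    #   step
    #   params.transformer.h.0.attention.wq.kernel
    #   opt_state.1.0.count
    #   opt_state.1.0.mu.transformer.h.0.attention.wq.kernel   (AdamW first moment)
    #   opt_state.1.0.nu.transformer.h.0.attention.wq.kernel   (AdamW second moment)
    # In the EasyLM checkpoint the Flax module adds one more "params" level,
    # which is why the real keys start with "params.params." instead.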

This is the command I ran:

    python -m EasyLM.models.llama.convert_easylm_to_hf \
        --load_checkpoint='params::../open_llama_3m/d521488bc3194913871dd8f3617e8dbf/streaming_train_state_245000' \
        --tokenizer_path='../llama-13b-lora-hf/tokenizer.model' \
        --model_size='3m' \
        --output_dir='../openllama-3m-hf'

young-geng commented 1 year ago

This is expected. The opt_state entries you see are AdamW optimizer states. The correct way to convert a train state that includes optimizer state to HuggingFace is to use --load_checkpoint="trainstate_params::/path/to/trainstate" instead of --load_checkpoint="params::/path/to/trainstate". Please see the checkpoint documentation for more details.
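Applied to the command from the report above, that would look something like:

    python -m EasyLM.models.llama.convert_easylm_to_hf \
        --load_checkpoint='trainstate_params::../open_llama_3m/d521488bc3194913871dd8f3617e8dbf/streaming_train_state_245000' \
        --tokenizer_path='../llama-13b-lora-hf/tokenizer.model' \
        --model_size='3m' \
        --output_dir='../openllama-3m-hf'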

xzqxnet0990 commented 1 year ago

I created my own '3m' entry in LLAMA_STANDARD_CONFIGS in llama_model.py:

    '3m': {
        'vocab_size': 49953,
        'hidden_size': 64,
        'intermediate_size': 128,
        'num_hidden_layers': 4,
        'num_attention_heads': 8,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },

and the same '3m' entry in convert_easylm_to_hf.py:

    '3m': {
        'dim': 64,
        'intermediate_size': 128,
        'n_layers': 4,
        'n_heads': 8,
        'norm_eps': 1e-6,
    },

Because I use a different tokenizer.model file rather than the official LLaMA tokenizer.model, I run into a size-mismatch problem:

File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3278, in _load_pretrained_model raise RuntimeError(f"Error(s) in loading state_dict for {model.class.name}:\n\t{error_msg}") RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM: size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([49954, 64]) from checkpoint, the shape in current model is torch.Size([32000, 64]). size mismatch for lm_head.weight: copying a param with shape torch.Size([49954, 64]) from checkpoint, the shape in current model is torch.Size([32000, 64]). You may consider adding ignore_mismatched_sizes=True in the model from_pretrained method.

However, I changed the source code to pass ignore_mismatched_sizes=True:

model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, ignore_mismatched_sizes=True)

and that causes another error:

Traceback (most recent call last):
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 297, in <module>
    mlxu.run(main)
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 289, in main
    write_model(
  File "/data/ketadb/EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py", line 205, in write_model
    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, ignore_mismatched_sizes=True)
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2881, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3218, in _load_pretrained_model
    mismatched_keys += _find_mismatched_keys(
  File "/root/miniconda3/envs/EasyLM/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3141, in _find_mismatched_keys
    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
KeyError: 'model.layers.1.self_attn.q_proj.weight'

young-geng commented 1 year ago

You might want to change the vocab size in the conversion script in addition to the model file.

xzqxnet0990 commented 1 year ago

This problem turned out to be easy to fix. I added vocab_size to LLAMA_CONFIGS:

    '3b': {
        'vocab_size': 32000,
        'dim': 3200,
        'intermediate_size': 8640,
        'n_layers': 26,
        'n_heads': 32,
        'norm_eps': 1e-6,
    },

and loaded the LlamaConfig with that vocab_size:

    config = LlamaConfig(
        hidden_size=dim,
        intermediate_size=params["intermediate_size"],
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        vocab_size=params["vocab_size"],
    )

Then I tested it, and it worked.
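As a quick sanity check (hypothetical, reusing the output_dir from the conversion command earlier in this thread), the converted model should now load with the custom vocabulary size:

    import torch
    from transformers import LlamaForCausalLM

    # Load the converted checkpoint and confirm the embedding matrix matches the custom tokenizer.
    model = LlamaForCausalLM.from_pretrained("../openllama-3m-hf", torch_dtype=torch.float16)
    print(model.config.vocab_size)                     # should match the custom tokenizer, not 32000
    print(model.get_input_embeddings().weight.shape)   # rows should equal the vocab size above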