axolotl-ai-cloud / axolotl

https://axolotl-ai-cloud.github.io/axolotl/

Llama will not save properly #1947

Open · mfirth-truffle opened this issue 1 month ago

mfirth-truffle commented 1 month ago

Please check that this issue hasn't been reported before.

Expected Behavior

When training completes and I try to run inference with the model, it should load without error.

Current behaviour

The saved model is missing parameters, so it errors out when loading:

[2024-10-06 21:07:57,939] [ERROR] [axolotl.load_model:808] [PID:45370] [RANK:0] Error(s) in loading state_dict for LlamaForCausalLM:
        size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([131344896]) from checkpoint, the shape in current model is torch.Size([128266, 4096]).
        size mismatch for model.norm.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for lm_head.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128266, 4096]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
Traceback (most recent call last):
  File "/root/axolotl/src/axolotl/utils/models.py", line 710, in load_model
    model = AutoModelLoader.from_pretrained(
  File "/root/.venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/root/.venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4014, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/root/.venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4559, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
        size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([131344896]) from checkpoint, the shape in current model is torch.Size([128266, 4096]).
        size mismatch for model.norm.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for lm_head.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128266, 4096]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
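
For anyone hitting the same thing, it is worth confirming that the checkpoint itself was written with flattened or empty FSDP shards (rather than the loader being at fault) by listing the tensor shapes inside the saved file. A minimal sketch, assuming a single model.safetensors under the config's output_dir (adjust the path if the save produced sharded model-0000x-of-0000y.safetensors files):

from safetensors import safe_open

# Hypothetical path, matching output_dir: ./models in the config below.
ckpt_path = "./models/model.safetensors"

with safe_open(ckpt_path, framework="pt") as f:
    for name in f.keys():
        # get_slice() exposes the shape without materialising the tensor.
        print(name, f.get_slice(name).get_shape())

# A correctly gathered checkpoint should show 2-D weights such as
# model.embed_tokens.weight -> [vocab_size, 4096]; 1-D or zero-sized
# entries like the ones in the traceback mean the sharded FSDP
# parameters were written out as-is.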

Steps to reproduce

Train a model with my config and any pre-tokenized dataset, then try to run inference on it.

Config yaml

base_model: meta-llama/Llama-3.1-8B-Instruct
tokenizer_config: ./tokenizer
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: true

save_safetensors: true

datasets:
  - path: ./processed_data.jsonl
    ds_type: json
    split: train[]
    type:

dataset_prepared_path: ./last_run_prepared

output_dir: ./models
sequence_len: 8192

wandb_project: llama-3.1-8b-inst
wandb_name: llama-3.1-8b-inst

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
learning_rate: 2e-5

bf16: auto
fp16:
tf32: false

logging_steps: 10
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
save_steps: 1
weight_decay: 0.0

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: true
  fsdp_cpu_ram_efficient_loading: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_backward_prefetch: BACKWARD_PRE
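
For context on what the FSDP block above is supposed to guarantee: with fsdp_state_dict_type: FULL_STATE_DICT, the sharded (flattened) FSDP parameters should be gathered back into their original shapes on rank 0 before the safetensors file is written, and the traceback above is consistent with that gather not happening. A minimal sketch of the underlying PyTorch mechanism (not axolotl's actual save path; model is assumed to be the FSDP-wrapped LlamaForCausalLM and torch.distributed is assumed to be initialised):

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather full, unflattened parameters onto rank 0 only.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()

if torch.distributed.get_rank() == 0:
    # Every entry should now have its original module shape,
    # e.g. model.embed_tokens.weight -> [vocab_size, 4096].
    for key, tensor in list(state_dict.items())[:5]:
        print(key, tuple(tensor.shape))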

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

main

Acknowledgements

mfirth-truffle commented 1 month ago

For anyone who stumbles across this in the future: just don't use FSDP.

NanoCode012 commented 1 month ago

Did the total model size bloat or otherwise look very different from the original's size?
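
A quick way to check, assuming the checkpoints live under the ./models output_dir from the config above:

from pathlib import Path

# Hypothetical path, matching output_dir: ./models in the issue's config.
out_dir = Path("./models")

total_bytes = sum(p.stat().st_size for p in out_dir.glob("*.safetensors"))
print(f"{total_bytes / 1e9:.2f} GB of safetensors")

# An 8B-parameter model saved in bf16 should come to roughly 16 GB;
# a total far above or below that would point at a broken save.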

bursteratom commented 1 month ago

Hi @mfirth-truffle

I used a similar configuration file to train the model and was able to run inference without any errors.

[Screenshot: Image 10-18-24 at 10 41 AM]

I made sure that my FSDP configuration matches yours from the yml above.

Here is mine:

base_model: meta-llama/Llama-3.1-8B-Instruct

save_safetensors: true

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca

dataset_prepared_path: ./last_run_prepared

output_dir: ./outputs/fft-out
sequence_len: 8192

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
learning_rate: 2e-5

bf16: auto
fp16:
tf32: false

logging_steps: 10
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 2
save_steps: 2
max_steps: 5
weight_decay: 0.0

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: true
  fsdp_cpu_ram_efficient_loading: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_backward_prefetch: BACKWARD_PRE

special_tokens:
  pad_token: "<|end_of_text|>"

Perhaps your issue has to do with your (presumably) customised tokenizer config? Would you be able to provide it so I can dig deeper? Thanks!
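
One thing that is easy to check on that front: the expected shapes in the traceback are [128266, 4096], while stock Llama-3.1 uses a 128256-token vocabulary, so the custom tokenizer appears to add about ten tokens and therefore forces an embedding resize. A minimal sketch for confirming the mismatch, using the paths from the issue's config (./tokenizer is the custom tokenizer directory; nothing here is axolotl-specific):

from transformers import AutoTokenizer

custom = AutoTokenizer.from_pretrained("./tokenizer")
stock = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# If the custom tokenizer is larger, model.embed_tokens / lm_head get
# resized before training, which is worth ruling out as an interaction
# with the FSDP save path.
print("custom vocab:", len(custom))
print("stock vocab:", len(stock))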