axolotl-ai-cloud / axolotl


Error loading intermediate checkpoints with flash_attn_fuse_mlp: true #1558

Closed RicardoDominguez closed 4 months ago

RicardoDominguez commented 5 months ago

Expected Behavior

When using LlamaForCausalLM models with flash_attn_fuse_mlp: true, intermediate checkpoints should load properly using from_pretrained.

Current behaviour

Instead of the checkpoint loading normally via from_pretrained, some of the model weights are randomly initialized. The fused checkpoint stores the MLP weights under mlp.swiglu.w12 and mlp.swiglu.w3 keys, which do not match the standard gate_proj / up_proj / down_proj parameter names, so from_pretrained discards them and re-initializes those layers:

Some weights of the model checkpoint at /fast/rolmedo/test/v4/all-llama-3-8k-lr_2e-5/checkpoint-6050/ were not used when initializing LlamaForCausalLM: ['model.layers.0.mlp.swiglu.w12.weight', 'model.layers.0.mlp.swiglu.w3.weight', 'model.layers.1.mlp.swiglu.w12.weight', 'model.layers.1.mlp.swiglu.w3.weight', 'model.layers.10.mlp.swiglu.w12.weight', 'model.layers.10.mlp.swiglu.w3.weight', 'model.layers.11.mlp.swiglu.w12.weight', 'model.layers.11.mlp.swiglu.w3.weight', 'model.layers.12.mlp.swiglu.w12.weight', 'model.layers.12.mlp.swiglu.w3.weight', 'model.layers.13.mlp.swiglu.w12.weight', 'model.layers.13.mlp.swiglu.w3.weight', 'model.layers.14.mlp.swiglu.w12.weight', 'model.layers.14.mlp.swiglu.w3.weight', 'model.layers.15.mlp.swiglu.w12.weight', 'model.layers.15.mlp.swiglu.w3.weight', 'model.layers.16.mlp.swiglu.w12.weight', 'model.layers.16.mlp.swiglu.w3.weight', 'model.layers.17.mlp.swiglu.w12.weight', 'model.layers.17.mlp.swiglu.w3.weight', 'model.layers.18.mlp.swiglu.w12.weight', 'model.layers.18.mlp.swiglu.w3.weight', 'model.layers.19.mlp.swiglu.w12.weight', 'model.layers.19.mlp.swiglu.w3.weight', 'model.layers.2.mlp.swiglu.w12.weight', 'model.layers.2.mlp.swiglu.w3.weight', 'model.layers.20.mlp.swiglu.w12.weight', 'model.layers.20.mlp.swiglu.w3.weight', 'model.layers.21.mlp.swiglu.w12.weight', 'model.layers.21.mlp.swiglu.w3.weight', 'model.layers.22.mlp.swiglu.w12.weight', 'model.layers.22.mlp.swiglu.w3.weight', 'model.layers.23.mlp.swiglu.w12.weight', 'model.layers.23.mlp.swiglu.w3.weight', 'model.layers.24.mlp.swiglu.w12.weight', 'model.layers.24.mlp.swiglu.w3.weight', 'model.layers.25.mlp.swiglu.w12.weight', 'model.layers.25.mlp.swiglu.w3.weight', 'model.layers.26.mlp.swiglu.w12.weight', 'model.layers.26.mlp.swiglu.w3.weight', 'model.layers.27.mlp.swiglu.w12.weight', 'model.layers.27.mlp.swiglu.w3.weight', 'model.layers.28.mlp.swiglu.w12.weight', 'model.layers.28.mlp.swiglu.w3.weight', 'model.layers.29.mlp.swiglu.w12.weight', 'model.layers.29.mlp.swiglu.w3.weight', 'model.layers.3.mlp.swiglu.w12.weight', 'model.layers.3.mlp.swiglu.w3.weight', 'model.layers.30.mlp.swiglu.w12.weight', 'model.layers.30.mlp.swiglu.w3.weight', 'model.layers.31.mlp.swiglu.w12.weight', 'model.layers.31.mlp.swiglu.w3.weight', 'model.layers.4.mlp.swiglu.w12.weight', 'model.layers.4.mlp.swiglu.w3.weight', 'model.layers.5.mlp.swiglu.w12.weight', 'model.layers.5.mlp.swiglu.w3.weight', 'model.layers.6.mlp.swiglu.w12.weight', 'model.layers.6.mlp.swiglu.w3.weight', 'model.layers.7.mlp.swiglu.w12.weight', 'model.layers.7.mlp.swiglu.w3.weight', 'model.layers.8.mlp.swiglu.w12.weight', 'model.layers.8.mlp.swiglu.w3.weight', 'model.layers.9.mlp.swiglu.w12.weight', 'model.layers.9.mlp.swiglu.w3.weight']
- This IS expected if you are initializing LlamaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /fast/rolmedo/test/v4/all-llama-3-8k-lr_2e-5/checkpoint-6050/ and are newly initialized: ['model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.18.mlp.down_proj.weight', 'model.layers.18.mlp.gate_proj.weight', 'model.layers.18.mlp.up_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.19.mlp.gate_proj.weight', 'model.layers.19.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.20.mlp.gate_proj.weight', 'model.layers.20.mlp.up_proj.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.21.mlp.gate_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.22.mlp.gate_proj.weight', 'model.layers.22.mlp.up_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.23.mlp.gate_proj.weight', 'model.layers.23.mlp.up_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.24.mlp.up_proj.weight', 'model.layers.25.mlp.down_proj.weight', 'model.layers.25.mlp.gate_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.27.mlp.down_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.27.mlp.up_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.28.mlp.gate_proj.weight', 'model.layers.28.mlp.up_proj.weight', 'model.layers.29.mlp.down_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.29.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.30.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.31.mlp.down_proj.weight', 'model.layers.31.mlp.gate_proj.weight', 'model.layers.31.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 
'model.layers.6.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight']

Steps to reproduce

Train using examples/llama-2/fft_optimized.yml (the important part is flash_attn_fuse_mlp: true), then attempt to load one of the intermediate checkpoints using AutoModelForCausalLM.from_pretrained. I am using save_only_model=True, though I do not think this makes a difference. Note that with flash_attn_fuse_mlp: false everything works fine, and in all cases the final model also loads fine. A minimal reproduction is sketched below.
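
A minimal reproduction, using the checkpoint path from the log above (any intermediate checkpoint from such a run will do):

from transformers import AutoModelForCausalLM

# Any intermediate checkpoint saved during training with
# flash_attn_fuse_mlp: true; path taken from the log above.
model = AutoModelForCausalLM.from_pretrained(
    "/fast/rolmedo/test/v4/all-llama-3-8k-lr_2e-5/checkpoint-6050/"
)
# This emits the "Some weights ... were not used / newly initialized"
# warnings shown above, leaving the MLP projections randomly initialized.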

Config yaml

examples/llama-2/fft_optimized.yml

Possible solution

The following script can be used to fix the checkpoint's .safetensors files so that intermediate checkpoints load correctly: it splits each fused swiglu.w12 weight back into gate_proj and up_proj, renames swiglu.w3 to down_proj, and rebuilds the weight map in model.safetensors.index.json.

import os
import json
from tqdm import tqdm

import torch
from safetensors import safe_open
from safetensors.torch import save_file

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True, help="Directory containing the safetensors files")
    args = parser.parse_args()

    base_dir = args.base_dir
    safetensors_files = [f for f in os.listdir(base_dir) if f.endswith(".safetensors")]
    print("Processing", len(safetensors_files), "files")

    new_weight_map = {}
    for file in tqdm(safetensors_files):
        tensors = {}

        file_path = os.path.join(base_dir, file)
        # Load on CPU so the script does not require a GPU.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                tensor = f.get_tensor(k)

                if 'swiglu' in k:
                    # Recover the module prefix, e.g. "model.layers.1.mlp"
                    parts = k.split(".")
                    base_name = '.'.join(parts[:parts.index("swiglu")])
                    if "w12" in k:
                        # model.layers.1.mlp.swiglu.w12.weight -> model.layers.1.mlp.gate_proj.weight
                        #                                      -> model.layers.1.mlp.up_proj.weight
                        intermediate_size = tensor.shape[0] // 2
                        w1, w2 = torch.split(tensor, intermediate_size, dim=0)

                        tensors[base_name + ".gate_proj.weight"] = w1
                        tensors[base_name + ".up_proj.weight"] = w2
                    elif "w3" in k:
                        # model.layers.1.mlp.swiglu.w3.weight -> model.layers.1.mlp.down_proj.weight
                        tensors[base_name + ".down_proj.weight"] = tensor
                    else:
                        raise ValueError(f"Unknown swiglu weight: {k}")
                else:
                    tensors[k] = tensor

        # Overwrite the shard in place with the unfused tensor names.
        new_weight_map.update({k: file for k in tensors.keys()})
        save_file(tensors, file_path, metadata={'format': 'pt'})

    # Load the index file (assumes a sharded checkpoint with an index)
    file = 'model.safetensors.index.json'
    file_path = os.path.join(base_dir, file)
    with open(file_path, 'r') as f:
        index_file = json.load(f)

    # Update the weight map
    index_file['weight_map'] = new_weight_map

    # Save the updated index file
    with open(file_path, 'w') as f:
        json.dump(index_file, f)

    print("Done")

Python Version

3.10

axolotl branch-commit

main/60f5ce0

NanoCode012 commented 4 months ago

Is the final model working with that config, or are both the checkpoints and the final model broken?

I am not so clear on why that config option was added in the first place.

RicardoDominguez commented 4 months ago

The final model works fine (the model is unfused prior to being saved at the end of training). It is only the intermediate checkpoints that have the issue.

NanoCode012 commented 4 months ago

Hm, so if this setting is enabled, should we add a callback that unfuses these layers using the method you provided? A sketch of what that might look like is below.
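
Something like the following, assuming the w12/w3 splitting logic from the script above is factored into a helper (unfuse_safetensors and the callback name here are hypothetical):

import os

from transformers import TrainerCallback

class UnfuseMlpCallback(TrainerCallback):
    # Hypothetical callback: after each checkpoint save, rewrite the fused
    # swiglu weights back into gate/up/down projections on disk.
    def on_save(self, args, state, control, **kwargs):
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # unfuse_safetensors is assumed to wrap the splitting logic from
        # the script above.
        unfuse_safetensors(checkpoint_dir)
        return control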

winglian commented 4 months ago

There is a workaround to use the checkpoints, which is to post-process them to unfuse the modules:

from pathlib import Path

import fire
import transformers

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer

def do_cli(config: Path = Path("examples/"), **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parser = transformers.HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    parsed_cli_args.merge_lora = True
    parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)

    model, tokenizer = load_model_and_tokenizer(cfg=parsed_cfg, cli_args=parsed_cli_args)
    # Fused modules expose _post_training, which unfuses their weights
    # in place before saving.
    for name, module in model.named_modules():
        if hasattr(module, "_post_training"):
            module._post_training(model, name)  # pylint: disable=protected-access
    model.save_pretrained(
        str(Path(parsed_cfg.output_dir) / "post_processed"),
        safe_serialization=True,
    )

if __name__ == "__main__":
    fire.Fire(do_cli)
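
With the training config pointing at the checkpoint to process, this can be invoked as, e.g. (the script name is arbitrary):

python unfuse_cli.py --config examples/llama-2/fft_optimized.yml

The unfused model is then written to <output_dir>/post_processed.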