RicardoDominguez closed this issue 4 months ago
Is the final model working with that config, or are both the checkpoints and the final model broken?
I am not clear on why that config option was added.
The final model works fine (the model is unfused prior to being saved at the end of training). It is only the intermediate checkpoints that have the issue.
Hm, so should we add a callback that unfuses these layers (using the method you provided) whenever this setting is enabled?
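For context, a rough sketch of what such a callback could look like, assuming a transformers TrainerCallback whose on_save hook runs after the checkpoint directory has been written. The class name and the deep-copy approach are illustrative only, not axolotl's actual implementation:

# Rough sketch only: after each checkpoint is written, unfuse a *copy* of the
# model and overwrite the checkpoint weights with the unfused state, leaving
# the live (fused) training model untouched. Name and approach are hypothetical.
import copy
import os

from transformers import TrainerCallback


class UnfuseCheckpointCallback(TrainerCallback):  # hypothetical name
    def on_save(self, args, state, control, model=None, **kwargs):
        # Trainer invokes on_save after writing checkpoint-{global_step}
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        unfused = copy.deepcopy(model)  # memory cost of the copy is ignored here
        for name, module in unfused.named_modules():
            if hasattr(module, "_post_training"):
                # same unfuse hook used by the post-processing script below
                module._post_training(unfused, name)  # pylint: disable=protected-access
        unfused.save_pretrained(checkpoint_dir, safe_serialization=True)
        del unfused
        return control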
There is a workaround for using the checkpoints, which is to post-process them to unfuse the modules:
from pathlib import Path

import fire
import transformers

from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer


def do_cli(config: Path = Path("examples/"), **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parser = transformers.HfArgumentParser((TrainerCliArgs))
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    parsed_cli_args.merge_lora = True

    parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)

    model, tokenizer = load_model_and_tokenizer(cfg=parsed_cfg, cli_args=parsed_cli_args)

    # Unfuse the fused modules via their _post_training hook, then save the
    # repaired weights alongside the original output directory.
    for name, module in model.named_modules():
        if hasattr(module, "_post_training"):
            module._post_training(model, name)  # pylint: disable=protected-access

    model.save_pretrained(
        str(Path(parsed_cfg.output_dir) / "post_processed"),
        safe_serialization=True,
    )


if __name__ == "__main__":
    fire.Fire(do_cli)
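Assuming the script above is saved as, say, unfuse_checkpoint.py (the filename and the exact way the config points at the checkpoint are just for illustration), it can be run via fire with something like python unfuse_checkpoint.py --config path/to/your_config.yml, and the unfused weights are written to output_dir/post_processed.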
Please check that this issue hasn't been reported before.
Expected Behavior
When using LlamaForCausalLM models and flash_attn_fuse_mlp: true, intermediate checkpoints cannot be loaded properly using from_pretrained (some of the model weights are randomly initialized).
Current behaviour
Instead of the model checkpoint loading normally using from_pretrained, some of the model weights are randomly initialized.
Steps to reproduce
Train using examples/llama-2/fft_optimized.yml (the important part is flash_attn_fuse_mlp: true). Then attempt to load one of the checkpoints using AutoModelForCausalLM.from_pretrained. I am using save_only_model=True, but I do not think that this makes a difference. Note that with flash_attn_fuse_mlp: false everything works fine. In all cases the final model also loads fine.
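For reference, a minimal sketch of the loading step that surfaces the symptom (the checkpoint path below is a placeholder for an intermediate checkpoint from the run):

# Placeholder checkpoint path; point it at any intermediate checkpoint.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("outputs/checkpoint-500")
# With flash_attn_fuse_mlp: true, this emits the usual "some weights were
# newly initialized" warning for the original (unfused) MLP projection layers.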
Config yaml
examples/llama-2/fft_optimized.yml
Possible solution
The post-processing script shown above can be used to fix the model .safetensors so that intermediate checkpoints load fine.
Which Operating Systems are you using?
Python Version
3.10
axolotl branch-commit
main/60f5ce0
Acknowledgements