Open · zhj2022 opened this issue 3 days ago
> As mentioned in https://github.com/ContextualAI/gritlm?tab=readme-ov-file#run: "After training, you may first have to run `python scripts/reformat_statedict.py path_to_statedict` to remove the `model.` prefix from the checkpoint."

This is how we run it, hope it solves the problem! (The ckpt should always be a PyTorch model, not safetensors, per https://github.com/ContextualAI/gritlm/blob/7df395df133a75a4580aa81f8d5b197c7cfa03ee/gritlm/training/arguments.py#L150.)
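For reference, the reformatting amounts to stripping the wrapper's prefix from the checkpoint keys; a minimal sketch of the idea (not the actual script) might look like:

```python
import sys

import torch

path = sys.argv[1]
state_dict = torch.load(path, map_location="cpu")
# The GritLM wrapper stores the transformer under a `model` attribute, so
# keys come out as e.g. "model.model.embed_tokens.weight"; drop the leading
# wrapper prefix so from_pretrained() can match the keys again.
state_dict = {
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in state_dict.items()
}
torch.save(state_dict, path)
```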
---

The original run.py saves the model as pytorch_model.bin, which cannot be loaded directly by the inference code provided in this repository. I replaced `trainer.save_model()` at line 422 of training/run.py with `model.model.save_pretrained(training_args.output_dir)`.
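In context, the swap looks roughly like this (a sketch of my edit, not the repository's code; my reading is that `model` is the GritLM wrapper and `model.model` is the underlying Hugging Face transformer):

```python
# training/run.py, around line 422 (sketch)

# Before: writes pytorch_model.bin whose keys carry the wrapper's
# "model." prefix, so it cannot be loaded with from_pretrained() as-is.
# trainer.save_model()

# After: saves the inner transformer in standard from_pretrained() layout.
model.model.save_pretrained(training_args.output_dir)
```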
With that change, the saved model can be used directly for inference. However, when using FSDP this method no longer works. For example, when I launch training with:
```
accelerate launch --config_file config_4gpusfsdp_llama.yml \
    --num_machines 1 --num_processes 4 \
    -m training.run \
    --output_dir llama3test \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --train_data training/toy_data \
    --learning_rate 1e-5 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 1 \
    --dataloader_drop_last True \
    --normalized True \
    --temperature 0.02 \
    --query_max_len 32 \
    --passage_max_len 128 \
    --train_group_size 2 \
    --mode unified \
    --attn cccc \
    --attn_implementation sdpa \
    --no_gen_gas --no_emb_gas --split_emb \
    --bf16
```
I get a trained model stored in ./llama3test. But when I execute the following code for inference, an error occurs:
I wonder how the authors saved their models when training, such that they can be used directly with the inference code in README.md.
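For context, the usual Accelerate recipe for materializing a full, directly loadable checkpoint under FSDP looks roughly like the sketch below; the `accelerator` handle and the `.model` unwrapping are my assumptions, not code from this repository:

```python
# Sketch: all-gather the FSDP shards into a full state dict and save it in
# from_pretrained() layout from the main process only.
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)  # gathers full FSDP state dict
unwrapped = accelerator.unwrap_model(model)     # GritLM wrapper (assumption)
unwrapped.model.save_pretrained(                # inner HF transformer
    training_args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=state_dict,
)
```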