allenai / open-instruct


Model weights after finetuning cannot be loaded by vLLM or Hugging Face #90

Closed. ZiyiLiubird closed this issue 5 months ago.

ZiyiLiubird commented 7 months ago

Dear authors, thank you for open-sourcing this great project. After finetuning my Llama-2 7B model with finetune_with_accelerate.sh, I cannot load the model weights with vLLM or Hugging Face during inference.

It seems that the model was saved successfully:

100%|██████████| 1300/1300 [6:12:02<00:00, 15.33s/it]
12/14/2023 01:59:25 - INFO - main - Step: 1300, LR: 1.8040265793910542e-05, Loss: 0.2729374170303345
tokenizer config file saved in output/llama2_sharegpt_7B/tokenizer_config.json
Special tokens file saved in output/llama2_sharegpt_7B/special_tokens_map.json
added tokens file saved in output/llama2_sharegpt_7B/added_tokens.json
Configuration saved in output/llama2_sharegpt_7B/config.json
Configuration saved in output/llama2_sharegpt_7B/generation_config.json
The model is bigger than the maximum size per checkpoint (5GB) and is going to be split in 3 checkpoint shards. You can find where each parameters has been saved in the index located at output/llama2_sharegpt_7B/model.safetensors.index.json.
100%|██████████| 1300/1300 [6:12:18<00:00, 17.18s/it]

But when I use vLLM to load the model weights, I get this error:

(lzy-rlhf) liuziyi@g0003:/paratera5-data/private/liuziyi/mygit/open-instruct$ bash scripts/eval/gsm.sh
Loading data...
Loading model and tokenizer...
2023-12-14 10:46:39,885 INFO worker.py:1673 -- Started a local Ray instance.
INFO 12-14 10:46:44 llm_engine.py:73] Initializing an LLM engine with config: model='/paratera5-data/private/liuziyi/mygit/open-instruct/output/llama2_sharegpt_7B', tokenizer='/paratera5-data/private/liuziyi/mygit/open-instruct/output/llama2_sharegpt_7B', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=8, quantization=None, seed=0)
INFO 12-14 10:46:44 tokenizer.py:32] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
Traceback (most recent call last):
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/paratera5-data/private/liuziyi/mygit/open-instruct/eval/gsm/run_eval.py", line 247, in <module>
    main(args)
  File "/paratera5-data/private/liuziyi/mygit/open-instruct/eval/gsm/run_eval.py", line 78, in main
    model = vllm.LLM(
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/entrypoints/llm.py", line 93, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 246, in from_engine_args
    engine = cls(*engine_configs,
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 107, in __init__
    self._init_workers_ray(placement_group)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 194, in _init_workers_ray
    self._run_workers(
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 750, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 727, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/ray/_private/worker.py", line 2563, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::RayWorkerVllm.execute_method() (pid=1415438, ip=10.232.14.3, actor_id=735b979085096463debd933f01000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x14675a4a8a60>)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/worker/worker.py", line 72, in load_model
    self.model_runner.load_model()
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 36, in load_model
    self.model = get_model(self.model_config)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/model_executor/model_loader.py", line 98, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/model_executor/models/llama.py", line 336, in load_weights
    weight_loader(param, loaded_weight)
  File "/ssd/apps/anaconda/2023.03/envs/lzy-rlhf/lib/python3.9/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 80, in weight_loader
    assert loaded_weight.shape[parallel_dim] == self.num_embeddings
AssertionError

I would be very grateful if you could give me some insight into how to deal with this issue.
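In case it helps, the assertion appears to compare the embedding weight stored in the checkpoint against the vocab size vLLM expects. Here is a rough diagnostic sketch (not from the repo; it assumes the standard Llama weight name model.embed_tokens.weight and the shard index file written above) to see what shape actually got saved:

import json
import os

from safetensors import safe_open

ckpt_dir = "output/llama2_sharegpt_7B"  # the finetuned checkpoint directory from the log above

# vocab size recorded in the Hugging Face config
with open(os.path.join(ckpt_dir, "config.json")) as f:
    vocab_size = json.load(f)["vocab_size"]

# look up which shard holds the embedding weight (assumed tensor name for Llama checkpoints)
with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

shard_file = weight_map["model.embed_tokens.weight"]
with safe_open(os.path.join(ckpt_dir, shard_file), framework="pt") as f:
    embedding = f.get_tensor("model.embed_tokens.weight")

print("config vocab_size:", vocab_size)
print("saved embedding shape:", tuple(embedding.shape))  # a mismatch here would trip the vLLM assertion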

hamishivi commented 7 months ago

Hi, I think this is an issue with safetensors combined with the older version of accelerate we use due to LR scheduler bugs. Try setting safe_serialization to False in save_with_accelerate, like so:

def save_with_accelerate(accelerator, model, tokenizer, output_dir, args):
    unwrapped_model = accelerator.unwrap_model(model)
    # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
    # Otherwise, sometimes the model will be saved with only part of the parameters.
    # Also, accelerator needs to use the wrapped model to get the state_dict.
    state_dict = accelerator.get_state_dict(model)
    if args.use_lora:
        # When using LoRA, the unwrapped model is a PeftModel, which doesn't support the is_main_process
        # argument and has its own save_pretrained function that only saves the LoRA modules.
        # We therefore have to check is_main_process ourselves before calling save_pretrained.
        if accelerator.is_main_process:
            unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
    else:
        # don't use safetensors for saving for now
        unwrapped_model.save_pretrained(
            output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict,
            safe_serialization=False
        )
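With safe_serialization=False, save_pretrained falls back to writing the classic pytorch_model-*.bin shards instead of *.safetensors, which should sidestep the bad safetensors output produced with the older accelerate version.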

We will make this fix in the code shortly!
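Once a checkpoint has been re-saved this way, a quick sanity check (a minimal sketch, not one of the repo's eval scripts; the prompt and tensor_parallel_size are placeholders) is to construct the vLLM engine on the new output directory and generate a few tokens:

import vllm

ckpt_dir = "output/llama2_sharegpt_7B"  # placeholder: the re-saved checkpoint directory

# if the constructor succeeds, the saved weight shapes match what vLLM expects
llm = vllm.LLM(model=ckpt_dir, tokenizer=ckpt_dir, tensor_parallel_size=1)
outputs = llm.generate(["Hello, my name is"], vllm.SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)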