EleutherAI / lm-evaluation-harness

A framework for few-shot evaluation of language models.

Inconsistencies with Model Performance when using LoRA Requests & vLLM #2432

Open ckgresla opened 1 week ago

ckgresla commented 1 week ago

I'm curious about other folks' experiences with using vLLM and lm-eval. When using a command like the following:

export TASK="custom-function-calling-task"
export lora_adapter_dir="/some/path/where/an/adapter_config/lives"

lm_eval --model vllm    \
    --model_args  pretrained=meta-llama/Llama-3.1-8B-Instruct,enable_lora=True,lora_local_path=$lora_adapter_dir,dtype=auto     \
    --task $TASK    \
    --limit 30    \
    --batch_size auto                \
    --log_samples                    \
    --output_path out

I believe this calls into the API correctly (see discussion here), but the results of evaluating an adapter/model with this command look as if no adapter were included in the evaluation: running an evaluation with a fine-tuned adapter and vLLM produces exactly the same (greedily decoded) results as running the evaluation with just the base model. I have confirmed that the adapter does get loaded, and that at runtime self.lora_request holds a valid value inside lm_eval/models/vllm_causallms.py. To verify that the adapter weights weren't "borked" in my case, I merged the adapter into the base model and re-ran the evaluation with the merged model -- the results were quite different from the other greedy results.
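
For anyone who wants to reproduce this sanity check outside of lm-eval, here is a minimal sketch that compares base vs. adapter outputs directly through vLLM's offline API. The prompt and adapter path are placeholders, and LoRARequest is just vLLM's per-request way of attaching an adapter:

# sanity check outside of lm-eval: compare base vs. adapter outputs via vLLM directly
# (prompt and adapter path are placeholders)
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, dtype="auto")
params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy, matching the greedy decoding above
lora = LoRARequest("my-adapter", 1, "/some/path/where/an/adapter_config/lives")

base_out = llm.generate(["<some eval prompt>"], params)
tuned_out = llm.generate(["<some eval prompt>"], params, lora_request=lora)
# if the adapter is actually applied, these should generally differ
print(base_out[0].outputs[0].text == tuned_out[0].outputs[0].text)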

Here are some relevant values of model_args I used and screenshots of corresponding eval results:

Has anyone else come across this sort of issue? Might there be a mismatch between the way LoRA requests are issued in lm-eval and what the integrated vLLM expects?

As an aside, does someone know exactly where enable_lora gets used? It presumably is parsed as a kwarg in lm_eval/models/vllm_causallms.py, but I don't see it applied/referenced anywhere in the LLM init (reference: vllm/entrypoints/llm.py).
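
My current understanding (an assumption worth double-checking against the vLLM source) is that LLM forwards its extra keyword arguments into EngineArgs, which is where enable_lora is actually consumed; roughly:

# assumption: extra **kwargs passed to vllm.LLM(...) end up in EngineArgs,
# whose enable_lora flag is what turns on LoRA support in the engine
from vllm import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,
)
print(engine_args.enable_lora)  # True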

ckgresla commented 1 week ago

Mmmm, the issue could be here:

# commit: 67a990e in lm_eval/models/vllm_causallms.py
        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[int]]
            ):
                llm = LLM(**model_args)
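                # (note: no lora_request is forwarded to llm.generate here,
                #  unlike the non-data-parallel path below)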
                return llm.generate(
                    prompt_token_ids=requests, sampling_params=sampling_params
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        if self.lora_request is not None:
            outputs = self.model.generate(
                prompt_token_ids=requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
                lora_request=self.lora_request,
            )

When using data parallelism with vLLM, it appears that LoRA requests are not distributed to the workers at all; run_inference_one_model only dispatches base-model requests.
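
A possible fix (just a sketch, assuming llm.generate on the Ray workers accepts lora_request the same way the non-data-parallel path does) would be to thread self.lora_request through the Ray task:

            # sketch: forward self.lora_request to every Ray worker alongside the model args
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[int]], lora_request=None
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]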