predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Llama3-8b-Instruct won't stop generating #442

Open ekim322 opened 2 months ago

ekim322 commented 2 months ago

System Info

lorax-client==0.5.0

Information

Tasks

Reproduction

.

Expected behavior

I use the code below to get the LLM response.

import os

from predibase import Predibase

pb = Predibase(api_token=os.environ.get("PREDIBASE_API_TOKEN"))
lorax_client = pb.deployments.client("llama-3-8b-instruct")
lorax_client.generate(
    prompt,
    adapter_id="ekim322/cpAdapter",
    adapter_source="hub",
    api_token=os.environ.get("HF_W_TOKEN"),
    max_new_tokens=512,
).generated_text

Llama3 keeps generating tokens until it hits max_new_tokens. It looks like the eos_token_id is never registered. I had a similar issue running locally, and updating transformers to >4.40 solved it there. The issue is related to Llama3-8b and Llama3-8b-Instruct having different eos_tokens.
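For anyone debugging this, one way to check which end-of-sequence tokens actually get registered is to inspect the tokenizer and generation config directly (a quick sketch, not part of the original report; it assumes access to the gated meta-llama/Meta-Llama-3-8B-Instruct repo):

import os

from transformers import AutoTokenizer, GenerationConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
hf_token = os.environ.get("HF_W_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
gen_cfg = GenerationConfig.from_pretrained(model_id, token=hf_token)

# On recent revisions of the Instruct repo (with transformers >4.40), the
# generation config should list both 128001 (<|end_of_text|>) and
# 128009 (<|eot_id|>) as eos ids.
print("tokenizer eos:", tokenizer.eos_token, tokenizer.eos_token_id)
print("generation_config eos_token_id:", gen_cfg.eos_token_id)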

I tried setting stop_sequences:

lorax_client.generate(
    prompt,
    adapter_id="ekim322/cpAdapter",
    adapter_source="hub",
    api_token=os.environ.get("HF_W_TOKEN"),
    max_new_tokens=512,
    stop_sequences=['<|end_of_text|>', '<|eot_id|>']
).generated_text

but this returns an empty string as the response. What is the proper way to set stop tokens?

Am I setting up Predibase correctly?

tgaddair commented 1 month ago

Hey @ekim322, we recently made some changes to fix this in #456. Can you try with the latest LoRAX version to see if the error persists?

micholeodon commented 1 week ago

@ekim322 does the change mentioned by @tgaddair fix your problem? If yes, please tell how. If not, but you have found some other solution, please share it :)

I have exactly the same issue with LoRAX and Llama3-8B-Instruct - the model tries to use up the whole max_new_tokens limit.

ekim322 commented 1 week ago

@micholeodon Updating the transformers library and retraining the model solved the issue for me. (I am not 100% sure whether that was the fix or #456 - inference is working fine for me now, and I think my issue got solved on my end before #456 was implemented.)

If I remember correctly, there were some errors with the Llama 3 Instruct chat template in older versions of the Transformers library.

I'd recommend updating all the libraries and training the model again.

micholeodon commented 1 week ago

Thank you very much for your comment.

Speaking of training, I use the same version of transformers to (1) query the model "manually" (via transformers.pipeline) and (2) query the model via LoRAX.

(1) works like a charm; (2) is greedy and uses up all the tokens.

If method (1) works, then I don't expect that updating the library or retraining the model would help. What do you think?
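For context, method (1) is presumably something along these lines (an illustrative sketch with a placeholder prompt, not the exact script used in the thread):

from transformers import AutoTokenizer, pipeline

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed base model
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    device_map="auto",  # requires accelerate
)

messages = [{"role": "user", "content": "Hello, who are you?"}]
# apply_chat_template inserts the Llama 3 special tokens for us.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

out = pipe(prompt, max_new_tokens=512, return_full_text=False)
print(out[0]["generated_text"])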

micholeodon commented 1 week ago

I have just solved the problem: I used the proper chat template for Llama3-8B-Instruct. Essentially, make sure that the string you pass to the LoRAX inputs parameter of the /generate endpoint contains the proper special tokens, such as <|begin_of_text|>, <|eot_id|>, <|start_header_id|>, <|end_header_id|>, and <|end_of_text|>.
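For illustration, this is roughly what such a prompt string looks like for Llama3-8B-Instruct when sent through the client from earlier in the thread (a sketch with placeholder content, not the exact prompt used):

# Llama 3 Instruct chat format, written out explicitly (placeholder content).
prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Summarize the report in two sentences.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

# With the headers in place, the model emits <|eot_id|> when it is done
# instead of running until max_new_tokens.
resp = lorax_client.generate(prompt, max_new_tokens=512)
print(resp.generated_text)

The same string can also be produced with the tokenizer's apply_chat_template(messages, tokenize=False, add_generation_prompt=True), which is less error-prone than hand-writing the tokens.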

See:

ekim322 commented 1 week ago

My model was working fine with transformers, but not with lorax (same issue as @micholeodon).

When I last checked, Llama3 and Llama3-Instruct used different tokens (e.g. <|end_of_text|> as the base model's eos token and <|eot_id|> as the instruct model's eos token). I am not sure whether older versions of the transformers library handled this properly in the adapter/generation configs. Inference was still working fine with transformers, just not with LoRAX.
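For reference, the difference is easy to see by loading both tokenizers (a sketch; the printed values reflect recent revisions of the gated meta-llama repos):

from transformers import AutoTokenizer

base = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
instruct = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

print(base.eos_token)      # <|end_of_text|>
print(instruct.eos_token)  # <|eot_id|> on current revisions of the Instruct repo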

I updated the transformers library and retrained the model - inference worked fine with LoRAX right away, and I didn't have to make any adjustments.