meta-llama / llama3

The official Meta Llama 3 GitHub site

[Query] How to make text generation stop using certain stop_strings in Llama 3 in Hugging Face? #228

Open Acejoy opened 3 months ago

Acejoy commented 3 months ago

Hey all, I am using Hugging Face's transformers library to do text generation with the Llama 3 8B Instruct model. I want to stop generation upon encountering certain strings such as '\n'. Is there a way to achieve this in the transformers library? I looked into StoppingCriteria, but I couldn't get it running. Also, the Llama 3 tokenizer returns None when I run llama3_tokenizer.convert_tokens_to_ids(['\n']).

Any help is appreciated. Thanks.

subramen commented 3 months ago

You might want to ask on the transformers repo, as this is specific to their API. I thought model.generate would have a stop arg, but I don't see it in their docs.

You could use the eos_token_id arg like this:

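# Each entry must be a single token id; convert_tokens_to_ids only resolves
# strings that exist as one token in the tokenizer's vocabulary.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    tokenizer.convert_tokens_to_ids("my-stop-str"),  # placeholder stop token
]
model.generate(..., eos_token_id=terminators, ...)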

cc @osanseviero for HF expertise

Acejoy commented 3 months ago

Thanks for the reply. I will look into the transformers library.

osanseviero commented 3 months ago

Setting eos_token_id is the right approach 👍 Just make sure you are using the right token as the tokenizer expects it :) (e.g. spaces at the beginning, etc.)

Acejoy commented 3 months ago

> Setting eos_token_id is the right approach 👍 Just make sure you are using the right token as the tokenizer expects it :) (e.g. spaces at beginning, etc)

Could you give an example (specifically for '\n')? I tried the same, but was not successful. Thanks.
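
A minimal sketch of one way to find "the right token as the tokenizer expects it", offered as an assumption rather than a confirmed answer from the thread. The likely reason convert_tokens_to_ids(['\n']) returns None is that the vocabulary does not store the newline as the literal string '\n', so a direct lookup misses. Encoding the string with tokenizer.encode(..., add_special_tokens=False) lets the tokenizer do that mapping itself. The helper name single_token_stop_ids is hypothetical and not part of transformers; it only relies on the tokenizer having an encode method.

```python
def single_token_stop_ids(tokenizer, stop_strings):
    """Return one token id per stop string, using the tokenizer's own encoding.

    Raises ValueError if a stop string does not encode to exactly one token,
    because eos_token_id can only match single token ids.
    """
    ids = []
    for text in stop_strings:
        # encode() runs the tokenizer's own pre-processing, so "\n" is looked
        # up the way the vocabulary stores it rather than as a raw string.
        encoded = tokenizer.encode(text, add_special_tokens=False)
        if len(encoded) != 1:
            raise ValueError(f"{text!r} encodes to {len(encoded)} tokens: {encoded}")
        ids.append(encoded[0])
    return ids


# Intended usage with the setup from this thread (assumes `tokenizer`, `model`
# and `inputs` are already loaded; not executed here):
#
#   terminators = [
#       tokenizer.eos_token_id,
#       tokenizer.convert_tokens_to_ids("<|eot_id|>"),
#       *single_token_stop_ids(tokenizer, ["\n"]),
#   ]
#   outputs = model.generate(**inputs, eos_token_id=terminators)
```

One caveat with this approach: it only matches the exact token id. If the model emits a token that merges the newline with neighbouring characters (a double newline, for instance, is typically a separate token), generation will not stop on it, so multi-token or merged stop strings need a string-level check instead.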