OpenNMT / CTranslate2

Fast inference engine for Transformer models
https://opennmt.net/CTranslate2
MIT License

Gemma model - help needed #1728

Open carolinaxxxxx opened 2 weeks ago

carolinaxxxxx commented 2 weeks ago

Can anyone help with an example of inference with the Gemma model in CTranslate2? Unfortunately, there is no information about this model in the documentation.

Thx

minhthuc2502 commented 2 weeks ago

Hello, I will update the doc in the future. BTW, you can convert Gemma as mentioned in the Llama documentation.

ct2-transformers-converter --model google/gemma-7b --output_dir gemma_ct2
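
If the command-line converter is inconvenient, the same conversion can also be run from Python. This is a minimal sketch using CTranslate2's TransformersConverter API (not posted in the thread); the float16 quantization is optional and only shown as an example.

import ctranslate2

converter = ctranslate2.converters.TransformersConverter("google/gemma-7b")
# Writes the converted model to gemma_ct2/; quantization is optional.
converter.convert("gemma_ct2", quantization="float16")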

Then you can try this script:

import ctranslate2
import transformers

generator = ctranslate2.Generator("gemma_ct2")
tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-7b")

b_inst = '<start_of_turn>'  # Gemma chat-format turn delimiters
e_inst = '<end_of_turn>'
user_input = 'Ask something'
# Gemma chat format: a user turn, then the opening of the model turn.
prompt = b_inst + 'user\n' + user_input + e_inst + '\n' + b_inst + 'model\n'
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
results = generator.generate_batch([tokens], max_length=30, sampling_topk=10)
print(tokenizer.decode(results[0].sequences_ids[0]))
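
One quick sanity check worth doing (my suggestion, not something from the thread): print the tokens and confirm that the turn delimiters come back as single special tokens rather than being split into pieces, which is a common cause of garbage output.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-7b")
ids = tokenizer.encode('<start_of_turn>user\nAsk something<end_of_turn>\n<start_of_turn>model\n')
# '<start_of_turn>' and '<end_of_turn>' should each appear as one token.
print(tokenizer.convert_ids_to_tokens(ids))
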
carolinaxxxxx commented 1 week ago

@minhthuc2502 - Does CTranslate2 support openchat models, e.g. openchat/openchat-3.5-0106-gemma? I managed to perform the conversion to CT2, but I can't make it work properly. THX

minhthuc2502 commented 1 week ago

What is the error? I see the architecture defined in the openchat model is GemmaForCausalLM, so I think it should work.
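
If you want to double-check this yourself, the architecture is listed in the model's config. A minimal sketch using the standard transformers API (not posted in the thread):

import transformers

config = transformers.AutoConfig.from_pretrained("openchat/openchat-3.5-0106-gemma")
# Should print ['GemmaForCausalLM'], which the CTranslate2 converter recognizes.
print(config.architectures)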

carolinaxxxxx commented 1 week ago

@minhthuc2502 I use the code below for testing:

import ctranslate2
import transformers

generator = ctranslate2.Generator("/test/openchat35gemma", device="cuda", device_index=1)
tokenizer = transformers.AutoTokenizer.from_pretrained("/test/openchat35gemma")

prompt = f"GPT4 Correct User: Hello<end_of_turn>GPT4 Correct Assistant: Hi<end_of_turn>GPT4 Correct User: How are you today?<end_of_turn>GPT4 Correct Assistant: "
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
results = generator.generate_batch([tokens], max_length=4096, sampling_temperature=0.1, sampling_topk=1, sampling_topp=0.1, include_prompt_in_result=False)
print(tokenizer.decode(results[0].sequences_ids[0]))

The result is random characters. Where did I go wrong? tokenizer.model is on the path. THX.
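
The thread ends here without a resolution. One thing worth checking (my suggestion, not something from the thread): if the openchat tokenizer ships a chat template, you can let it build the prompt and compare the result with the hand-written string above. The model path below is the same local path used in the snippet above.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("/test/openchat35gemma")

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi"},
    {"role": "user", "content": "How are you today?"},
]
# If a chat template is available, this produces the prompt format the model was trained on.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)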