mostafamdy opened 2 months ago
Hi @mostafamdy,
It seems the sampler expects the prompt in the form of a list. The {{call_args}} in the base Sampler class are not defined properly, IMO.
Hi @mostafamdy, it seems like the guide is outdated. Thanks for bringing this up! You can refer to the Sampler API docs or the "Example Use" section on the Kaggle model card. For your use case, there's now a simpler API for plugging in different samplers:
import keras_nlp
model = keras_nlp.models.GemmaCausalLM.from_preset('gemma_2b_en')
# Tell KerasNLP to use a "greedy" sampler. Other options are "top_k", "top_p", etc.
# See https://keras.io/api/keras_nlp/samplers/ for more info
model.compile(sampler="greedy")
output = model.generate("What is Keras?", max_length=50)
# You can also initialize a sampler to configure it for your use case
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
model.compile(sampler=sampler)
model.generate(["What is Keras?"])
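For intuition about what the two configurations above do, here is a framework-free sketch of greedy selection versus top-k sampling applied to a single vector of next-token logits. It is not KerasNLP's implementation: softmax, greedy_pick and top_k_pick are hypothetical helpers, and the logits are made up.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Convert raw logits to probabilities, dividing by temperature first."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits):
    """Greedy decoding: always take the single most likely token id."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_pick(logits, k=5, temperature=0.7, rng=random):
    """Top-k sampling: keep the k highest logits, rescale them by temperature,
    and draw one token id from the renormalized distribution."""
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top_ids], temperature)
    return rng.choices(top_ids, weights=probs, k=1)[0]


# Toy next-token logits over a 6-token vocabulary.
logits = [2.0, 0.5, 3.1, -1.0, 1.2, 0.0]
print(greedy_pick(logits))                            # always 2
print(top_k_pick(logits, k=3, rng=random.Random(0)))  # one of 2, 0, 4
```

Greedy is deterministic, while top-k trades some likelihood for variety; a lower temperature concentrates the top-k draw on the highest logit.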
Does this answer your question?
Thanks @tirthasheshpatel! I want to use a sampler in a custom loss function. Is that possible? 😅
Is this code correct?
import keras
from keras import ops
def custom_loss(y_true, y_pred):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")
    probabilities = keras.activations.softmax(logits / temperature)
    next_token = ops.argmax(probabilities, axis=-1)
    print(next_token)
    txt = tokenizer.detokenize(next_token)
    print(f"Greedy search generated text: \n{txt}\n")
I tried calling the Gemma model like this:
preprocessor=keras_nlp.models.GemmaPreprocessor(
tokenizer, sequence_length=SEQ_LEN,
)
model_out = gemma_lm(preprocessor([data[0]]))
and then passed model_out to the custom_loss function to generate text:
def custom_loss(y_true, y_pred):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")
    probabilities = keras.activations.softmax(logits / temperature)
    next_token = ops.argmax(probabilities, axis=-1)
    txt = tokenizer.detokenize(next_token)
    print(f"Greedy search generated text: \n{txt}\n")
custom_loss("y_true",model_out)
But the output is different from what gemma_lm.generate() produces.
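One detail in custom_loss above: softmax is monotonic, and dividing by a positive temperature preserves the ordering of the logits, so neither the temperature nor the softmax itself can change the result of argmax. A small stdlib-only check with made-up logits (softmax and argmax here are hypothetical helpers, not Keras functions):

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over one vector of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


logits = [1.5, 4.0, -0.3, 3.9]  # made-up logits for a 4-token vocabulary
for t in (0.5, 1.0, 2.0):
    probs = softmax(logits, temperature=t)
    # The distribution flattens as t grows, but the winning index never changes.
    print(t, argmax(probs), round(max(probs), 3))
```

Temperature only matters once you sample from the distribution (as TopKSampler does); for greedy decoding it is a no-op.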
Ah, OK. Your code looks good to me. You seem to be printing the next-token prediction at every position of each input sequence, which I guess is why the outputs differ. Can you check whether this code generates the right output:
import keras
from keras import ops
import keras_nlp
model = keras_nlp.models.GemmaCausalLM.from_preset('gemma_2b_en')
preprocessor = model.preprocessor
tokenizer = preprocessor.tokenizer
backbone = model.backbone
def loss_fn(y_true, y_pred, prompt=None, index=None):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")
    # Compute probs and next token value
    probabilities = ops.softmax(logits[:, index, :], axis=-1)
    next_token = ops.argmax(probabilities, axis=-1)
    # Update the prompt
    prompt_tokens = tokenizer.tokenize(prompt)
    updated_prompt_tokens = ops.concatenate([prompt_tokens, next_token[..., None]], axis=-1)
    updated_prompt = tokenizer.detokenize(updated_prompt_tokens)
    # Print the updated prompt
    print(f"The updated prompt is: {updated_prompt}")
# Get the inputs
prompt = ["The quick brown"]
train_data = preprocessor(prompt, sequence_length=10)
index = ops.min(ops.sum(train_data[0]['padding_mask'], axis=-1)) - 2
# Evaluate the loss function
loss_fn(train_data[1], model(train_data[0]), prompt=prompt, index=index)
# The updated prompt is: [b'The quick brown fox']
# Check outputs match
model.generate(prompt, max_length=5)
# ['The quick brown fox']
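The index computation in the snippet above can be traced by hand. This is a minimal sketch, assuming (as the - 2 suggests) that the preprocessor's padding_mask also marks a start token and an end token around the prompt, so the last real prompt token sits two positions before the end of the unmasked run. The token strings and the helper last_prompt_index are hypothetical stand-ins for real Gemma token ids.

```python
def last_prompt_index(padding_mask):
    """Position of the last prompt token, assuming the mask covers
    [start token, prompt tokens..., end token] followed by padding."""
    return sum(padding_mask) - 2


# Hypothetical packed input for "The quick brown" with sequence_length=10.
tokens = ["<bos>", "The", " quick", " brown", "<eos>",
          "<pad>", "<pad>", "<pad>", "<pad>", "<pad>"]
padding_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

index = last_prompt_index(padding_mask)
# index == 3, the " brown" position, so logits[:, index, :] scores
# the token that should follow " brown".
print(index, repr(tokens[index]))
```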
Thank you so much ❤️ I tried this code and it works well, but there is something I don't understand. I tried changing the index into the logits:
for i in range(10):
    probabilities = ops.softmax(logits[:, i, :], axis=-1)
    next_token = ops.argmax(probabilities, axis=-1)
    updated_prompt = tokenizer.detokenize(next_token)
    print(updated_prompt)
The output was:
and fox fox,' the, the
When using generate with a sequence length of 10, I received:
The quick brown fox jumps over the sleeping dog
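The two procedures compute different things. A single forward pass returns, at every position i, the prediction for the token after position i given the fixed input up to i; positions after the prompt are conditioned on end and padding tokens rather than on generated text. generate() instead appends each predicted token and runs the model again. Below is a toy, stdlib-only illustration that uses a hypothetical bigram table in place of Gemma, so the exact words will not match the real model's output.

```python
# Hypothetical toy "language model": the next-token guess depends only on
# the current token (a bigram table). A real causal LM conditions on the
# whole prefix, but the shape of the computation is the same.
BIGRAM = {
    "<bos>": "The", "The": "quick", "quick": "brown", "brown": "fox",
    "fox": "jumps", "jumps": "over", "over": "the", "the": "lazy",
    "lazy": "dog", "<eos>": "the", "<pad>": "the",
}


def predict_each_position(tokens):
    """One forward pass: position i predicts token i + 1 from tokens[: i + 1].
    This is what looping over logits[:, i, :] reads out."""
    return [BIGRAM.get(tok, "<unk>") for tok in tokens]


def greedy_generate(prompt, max_new_tokens):
    """Autoregressive greedy decoding: append the last prediction and rerun."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(predict_each_position(tokens)[-1])
    return tokens


prompt = ["<bos>", "The", "quick", "brown"]
padded = prompt + ["<eos>", "<pad>", "<pad>"]
print(predict_each_position(padded))
# ['The', 'quick', 'brown', 'fox', 'the', 'the', 'the']
print(greedy_generate(prompt, 4))
# ['<bos>', 'The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the']
```

Only the prediction at the last prompt position extends the text; to get more tokens you have to feed that token back in, which is what generate() does.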
Describe the bug
Hi, I am trying the sampler example from https://keras.io/examples/generative/text_generation_gpt/ with Gemma.
The Gemma preprocessor returns a dictionary of token_ids and padding_mask, but the sampler does not accept a dictionary.
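In the keras.io GPT example the sampler is driven by a next(prompt, cache, index) callback over a plain array of token ids, so one way to bridge the mismatch described above is to unpack the preprocessor's dictionary yourself: pass token_ids as the prompt and derive the starting index from padding_mask. The sketch below shows only that plumbing, in plain Python with a toy greedy loop. toy_logits, next_fn, greedy_sample and the dictionary contents are hypothetical, not KerasNLP APIs, and it assumes the end token has been left off the prompt. The compile(sampler=...) route from the reply above avoids this entirely.

```python
# Hypothetical preprocessor output: a dict, as described in the issue, for one
# prompt of 4 real tokens (start token included, no end token) padded to 8.
batch = {
    "token_ids": [[1, 5, 6, 7, 0, 0, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}

# Hypothetical bigram "model": maps the previous token id to the next one.
TOY_NEXT_ID = {7: 8, 8: 9, 9: 3, 3: 4}


def toy_logits(prev_id, vocab_size=10):
    """Fake next-token logits: a large score for the bigram successor."""
    logits = [0.0] * vocab_size
    logits[TOY_NEXT_ID.get(prev_id, 0)] = 10.0
    return logits


def next_fn(prompt, cache, index):
    """Callback shaped like the GPT example's (prompt, cache, index) ->
    (logits, hidden_states, cache); it scores position `index` from the
    token at `index - 1`. Batch size 1 for simplicity."""
    logits = toy_logits(prompt[0][index - 1])
    return logits, None, cache


def greedy_sample(next_fn, prompt, index):
    """Greedily fill positions index..end of a list-of-lists prompt."""
    prompt = [list(row) for row in prompt]  # copy so the input is untouched
    cache = None
    for i in range(index, len(prompt[0])):
        logits, _, cache = next_fn(prompt, cache, i)
        prompt[0][i] = max(range(len(logits)), key=lambda t: logits[t])
    return prompt


# Unpack the dict: a sampler-style loop only needs the ids and a start index.
token_ids = batch["token_ids"]
start = sum(batch["padding_mask"][0])  # first padded position
print(greedy_sample(next_fn, token_ids, start))
# [[1, 5, 6, 7, 8, 9, 3, 4]]
```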
Sampler code
Error