keras-team / keras-io

Keras documentation, hosted live at
Apache License 2.0
2.69k stars 2.01k forks

TypeError: Sampler.__call__() got an unexpected keyword argument 'end_token_id' #1840

Closed okyanusoz closed 1 month ago

okyanusoz commented 2 months ago

Issue Type




Keras Version


Custom Code


OS Platform and Distribution


Python version


GPU model and memory

Colab T4 GPU

Current Behavior?

I've been trying to run the "English-to-Spanish translation with KerasNLP" example, but it fails at the prediction/evaluation step:


TypeError                                 Traceback (most recent call last)

<ipython-input-17-7f8fd291e436> in <cell line: 35>()
     35 for i in range(2):
     36     input_sentence = random.choice(test_eng_texts)
---> 37     translated = decode_sequences([input_sentence])
     38     translated = translated.numpy()[0].decode("utf-8")
     39     translated = (

<ipython-input-17-7f8fd291e436> in decode_sequences(input_sentences)
     22     prompt = ops.concatenate((start, pad), axis=-1)
---> 24     generated_tokens = keras_nlp.samplers.GreedySampler()(
     25         next,
     26         prompt,

TypeError: Sampler.__call__() got an unexpected keyword argument 'end_token_id'

The model trains successfully, though.

Standalone code to reproduce the issue or tutorial link

Colab: (Runtime: T4 GPU)

Relevant log output


TypeError                                 Traceback (most recent call last)

<ipython-input-17-7f8fd291e436> in <cell line: 35>()
     35 for i in range(2):
     36     input_sentence = random.choice(test_eng_texts)
---> 37     translated = decode_sequences([input_sentence])
     38     translated = translated.numpy()[0].decode("utf-8")
     39     translated = (

<ipython-input-17-7f8fd291e436> in decode_sequences(input_sentences)
     22     prompt = ops.concatenate((start, pad), axis=-1)
---> 24     generated_tokens = keras_nlp.samplers.GreedySampler()(
     25         next,
     26         prompt,

TypeError: Sampler.__call__() got an unexpected keyword argument 'end_token_id'
Tylman-M commented 2 months ago

This is how I fixed it in my code. I strongly suspect there's a better way, but in short, three things needed to change:

  1. Change 'end_token_id' to 'stop_token_ids'
  2. Wrap the token id argument in a list
  3. Add '.to_tensor()' to the encoder input tokens

(Note: I extended mine from Spanish to French, so my tokenizers are named "fr..." instead of "spa...".)

def decode_sequences(input_sentences):
    batch_size = 1

    # Tokenize the encoder input.
    encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
    if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
        pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)
        # encoder_input_tokens = ops.concatenate([encoder_input_tokens, pads], 1)  # <-- Original
        encoder_input_tokens = ops.concatenate([encoder_input_tokens.to_tensor(), pads], 1)  # <-- Added ".to_tensor()"

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def next(prompt, cache, index):
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of length 40 with a start token and padding tokens.
    length = 40
    start = ops.full((batch_size, 1), fr_tokenizer.token_to_id("[START]"))
    pad = ops.full((batch_size, length - 1), fr_tokenizer.token_to_id("[PAD]"))
    prompt = ops.concatenate((start, pad), axis=-1)

    generated_tokens = nlp.samplers.GreedySampler()(
        next,
        prompt,
        # end_token_id=fr_tokenizer.token_to_id("[END]"),  # <-- Original
        stop_token_ids=[fr_tokenizer.token_to_id("[END]")],  # <-- Renamed argument, value wrapped in a list
        index=1,  # Start sampling after start token.
    )
    generated_sentences = fr_tokenizer.detokenize(generated_tokens)
    return generated_sentences
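
If the same notebook has to run against both older and newer KerasNLP releases, one option is to inspect the sampler's `__call__` signature and pass whichever keyword it accepts. This is only a rough sketch, not something from the tutorial: `stop_kwargs` is a hypothetical helper, and `OldSampler` / `NewSampler` are stand-in classes that merely mimic the two keyword names mentioned in this thread, so the snippet runs without KerasNLP installed.

```python
import inspect


def stop_kwargs(sampler, end_id):
    """Return the stop-token keyword argument that `sampler.__call__` accepts.

    Hypothetical helper: chooses between the `stop_token_ids` (list of ids)
    and `end_token_id` (single id) spellings discussed in this thread.
    """
    params = inspect.signature(sampler.__call__).parameters
    if "stop_token_ids" in params:
        return {"stop_token_ids": [end_id]}
    if "end_token_id" in params:
        return {"end_token_id": end_id}
    raise TypeError("sampler accepts neither stop_token_ids nor end_token_id")


# Stand-ins for illustration only; these are NOT the real KerasNLP samplers.
class OldSampler:
    def __call__(self, next, prompt, end_token_id=None, index=0):
        return prompt


class NewSampler:
    def __call__(self, next, prompt, stop_token_ids=None, index=0):
        return prompt


print(stop_kwargs(OldSampler(), 3))  # {'end_token_id': 3}
print(stop_kwargs(NewSampler(), 3))  # {'stop_token_ids': [3]}
```

In `decode_sequences`, the call would then look something like `sampler(next, prompt, index=1, **stop_kwargs(sampler, end_id))`, where `sampler` is the `GreedySampler()` instance and `end_id` is the "[END]" token id. Pinning a single KerasNLP version is probably simpler if that is an option.
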
github-actions[bot] commented 1 month ago

Are you satisfied with the resolution of your issue?