keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0
758 stars 227 forks source link

Machine Translation With Transformers #1280

Open alerem18 opened 10 months ago

alerem18 commented 10 months ago

I can't run this example on the JAX or PyTorch backend; it only works on the TensorFlow backend.

https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/

Also, inference is significantly slower than a similar implementation in PyTorch — roughly 8 times slower.

mattdangerw commented 10 months ago

Thanks! We will take a look.

shivance commented 10 months ago

See #1189 for context.

alerem18 commented 10 months ago

@shivance And what about the decoding part?

def decode_sequences(input_sentences):
    """Greedily decode an output sentence for every input sentence.

    The inputs are tokenized and packed once, then a
    `keras_nlp.samplers.GreedySampler` fills in a fixed-length prompt one
    position at a time by repeatedly calling the module-level `transformer`.

    Args:
        input_sentences: batch of source sentences; its first dimension is
            used as the batch size (presumably a 1-D string tensor — confirm
            against the caller).

    Returns:
        The result of `tokenizer.detokenize` applied to the sampled tokens.
    """
    # Tokenize and pack the encoder input once; every sampling step reuses it.
    encoder_input_tokens = input_packer(tokenizer(input_sentences))

    def next_token_logits(prompt, cache, index):
        # Run the model on the prompt built so far and keep only the logits
        # at position `index - 1`.
        decoder_outputs = transformer([encoder_input_tokens, prompt])
        logits = decoder_outputs[:, index - 1, :]
        # No hidden states are returned; they are only needed for
        # contrastive search. The cache is passed through untouched.
        return logits, None, cache

    # Prompt layout: one start token followed by padding, for a total of
    # TARGET_MAX_SEQUENCE_LENGTH positions per row.
    batch_size = tf.shape(input_sentences)[0]
    start_column = tf.fill((batch_size, 1), START_VALUE)
    pad_columns = tf.fill(
        (batch_size, TARGET_MAX_SEQUENCE_LENGTH - 1), PAD_VALUE
    )
    prompt = tf.concat((start_column, pad_columns), axis=-1)

    sampler = keras_nlp.samplers.GreedySampler()
    generated_tokens = sampler(
        next_token_logits,
        prompt,
        end_token_id=END_VALUE,
        index=1,  # Position 0 already holds the start token.
    )

    return tokenizer.detokenize(generated_tokens)

# Exact-match evaluation: translate every test pair one sentence at a time
# with `decode_sequences` and count how many outputs equal the target.
test_eng_texts = [pair[0] for pair in test_pairs]

# Wrap the data itself rather than the `enumerate` object so tqdm can pick up
# a total from `test_pairs` when it has a length, and avoid shadowing the
# builtin `iter`.
progress = tqdm(test_pairs)
corrects = 0
for i, pair in enumerate(progress):
    input_sentence = pair[0]
    target_sentence = pair[1]

    # Decode a batch of one sentence and pull the single result out as text.
    translated = decode_sequences(tf.constant([input_sentence]))
    translated = translated.numpy()[0].decode("utf-8")
    # Strip the special tokens from the detokenized output.
    # NOTE(review): `.replace(' ', '')` removes every space, so the match
    # below only succeeds when `target_sentence` contains no spaces either —
    # confirm this is intended for the dataset in use.
    translated = (
        translated.replace(PAD_TOKEN, "")
        .replace(START_TOKEN, "")
        .replace(END_TOKEN, "")
        .replace(' ', '')
        .strip()
    )

    if translated == target_sentence:
        corrects += 1

    # Running accuracy over the examples seen so far.
    progress.set_postfix(corrects=corrects, accuracy=corrects / (i + 1))
    print(f"** Example {i} **")
    print(input_sentence)
    print(translated)
    print()