Machine Translation With Transformers #1280

Open alerem18 opened 10 months ago

alerem18 commented 10 months ago

I can't run this example on the JAX or PyTorch backend; it only works on the TensorFlow backend.

Also, inference is significantly slower than a similar implementation in PyTorch, roughly 8 times slower.

mattdangerw commented 10 months ago

Thanks! We will take a look.

shivance commented 10 months ago

See #1189 for context.

alerem18 commented 10 months ago

@shivance and how's the decoding part?

def decode_sequences(input_sentences):
    """Greedily decode a batch of input sentences with the transformer.

    Args:
        input_sentences: Batch of source sentences. Its first dimension is
            used as the batch size; the caller in this file passes a
            ``tf.constant`` wrapping a list with one string.

    Returns:
        The result of ``tokenizer.detokenize`` applied to the sampled tokens.
    """
    batch_size = tf.shape(input_sentences)[0]

    # Tokenize the encoder input.
    encoder_input_tokens = input_packer(tokenizer(input_sentences))

    # Callback handed to the sampler: it outputs the next token's logits
    # given the current prompt. Named `next_token_fn` rather than `next` so
    # the builtin is not shadowed.
    def next_token_fn(prompt, cache, index):
        # Logits are read at position `index - 1` of the transformer output.
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of `length` tokens: one start token followed by padding.
    # NOTE(review): `length` is not defined in this function; presumably a
    # module-level constant (the original comment mentions 40) - confirm.
    start = tf.fill((batch_size, 1), START_VALUE)
    pad = tf.fill((batch_size, length - 1), PAD_VALUE)
    prompt = tf.concat((start, pad), axis=-1)

    # The sampler call was left unclosed and was missing its two required
    # positional arguments (the next-token callback and the prompt).
    # NOTE(review): no stop/end token id is passed here because none is
    # visible in this snippet; without one the sampler presumably fills the
    # whole prompt. Pass the end-token id if one is available - confirm the
    # keyword name against the installed keras_nlp version.
    generated_tokens = keras_nlp.samplers.GreedySampler()(
        next_token_fn,
        prompt,
        index=1,  # Start sampling after start token.
    )

    return tokenizer.detokenize(generated_tokens)

# Exact-match evaluation: translate every test pair's source sentence and
# count how many outputs equal the target sentence.
test_eng_texts = [pair[0] for pair in test_pairs]
# Wrap the data itself (not the enumerate object) so tqdm can report a total
# whenever `test_pairs` has a length.
# NOTE(review): `iter` shadows the builtin; the name is kept because it is a
# module-level variable that code outside this snippet may reference.
iter = tqdm(test_pairs)
corrects = 0
for i, pair in enumerate(iter):
    input_sentence = pair[0]
    target_sentence = pair[1]

    translated = decode_sequences(tf.constant([input_sentence]))
    translated = translated.numpy()[0].decode("utf-8")
    # Strip the special tokens and every space before comparing. The closing
    # parenthesis of this expression was missing in the original.
    translated = (
        translated.replace(PAD_TOKEN, "")
        .replace(START_TOKEN, "")
        .replace(END_TOKEN, "")
        .replace(' ', '')
    )

    if translated == target_sentence:
        corrects += 1

    # Running accuracy over the examples processed so far.
    iter.set_postfix(corrects=corrects, accuracy=corrects / (i + 1))
    print(f"** Example {i} **")