huggingface / transformers

πŸ€— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

πŸš€ Faster batch translation with FSMT model #9994

Closed. itssimon closed this issue 3 years ago.

itssimon commented 3 years ago

πŸš€ Faster batch translation with FSMT model

Currently, generating translations for multiple inputs at once is very slow using Transformers' FSMTForConditionalGeneration implementation. In fact, it's about 10x slower than using the original FairSeq library. Can we speed this up by improving the implementation, potentially leaning on the original FairSeq approach?

Motivation

I'm using FairSeq models for back translation as a way to augment text data. I've implemented this using the original FairSeq model (from PyTorch Hub) and Transformers.

FairSeq implementation

from typing import List

import torch

en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe').cuda()
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe').cuda()

def back_translate_fairseq(texts: List[str]) -> List[List[str]]:
    tokenized_texts = [en2de.encode(text) for text in texts]
    back_translations = [set() for _ in range(len(texts))]

    # Translate texts to German
    tokenized_de_texts = [
        [output['tokens'].cpu() for output in batch_output]
        for batch_output in en2de.generate(tokenized_texts, beam=2, sampling=True, sampling_topp=0.7)
    ]
    tokenized_de_texts_flat = [t for tt in tokenized_de_texts for t in tt]

    # Translate back to English
    tokenized_en_texts = [
        [output['tokens'].cpu() for output in batch_output]
        for batch_output in de2en.generate(tokenized_de_texts_flat, beam=2, sampling=True, sampling_topp=0.8)
    ]
    tokenized_en_texts_flat = [t for tt in tokenized_en_texts for t in tt]

    # Decode and deduplicate back-translations and assign to original text indices
    for i, t in enumerate(tokenized_en_texts_flat):
        back_translations[i // 4].add(de2en.decode(t).lower())

    # Remove back translations that are equal to the original text
    return [[bt for bt in s if bt != t] for s, t in zip(back_translations, map(str.lower, texts))]

Transformers implementation

from typing import List

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

en2de_model_name = "facebook/wmt19-en-de"
en2de_tokenizer = FSMTTokenizer.from_pretrained(en2de_model_name)
en2de_model = FSMTForConditionalGeneration.from_pretrained(en2de_model_name)

de2en_model_name = "facebook/wmt19-de-en"
de2en_tokenizer = FSMTTokenizer.from_pretrained(de2en_model_name)
de2en_model = FSMTForConditionalGeneration.from_pretrained(de2en_model_name)

def back_translate_transformers(texts: List[str]) -> List[List[str]]:
    tokenized_texts = en2de_tokenizer.prepare_seq2seq_batch(texts, return_tensors="pt")
    back_translations = [set() for _ in range(len(texts))]

    # Translate texts to German and back to English
    generate_kwargs = {"num_beams": 1, "do_sample": True, "num_return_sequences": 2}
    tokenized_de_texts = en2de_model.generate(tokenized_texts["input_ids"], attention_mask=tokenized_texts["attention_mask"], top_p=0.7, **generate_kwargs)
    tokenized_en_texts = de2en_model.generate(tokenized_de_texts, top_p=0.8, **generate_kwargs)

    # Decode and deduplicate back-translations and assign to original text indices
    for i, t in enumerate(tokenized_en_texts):
        back_translations[i // 4].add(de2en_tokenizer.decode(t, skip_special_tokens=True).lower())

    # Remove back translations that are empty or equal to the original text
    return [[bt for bt in s if bt and bt != t] for s, t in zip(back_translations, map(str.lower, texts))]

Both of these functions generate comparable results, but the Transformers version takes about 10x longer.

In my use case I need back translations for hundreds of thousands of text snippets, which unfortunately makes the Transformers implementation infeasible. I'd still love to use Transformers, as it is much easier to install and deploy (we already use Transformers for text classification anyway).
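
As a rough illustration of how the two functions might be timed against each other (the sample texts and batch size below are illustrative, not the original benchmark):

import time

# Hypothetical input batch; the original benchmark data is not part of this issue
sample_texts = ["The quick brown fox jumps over the lazy dog."] * 32

start = time.perf_counter()
back_translate_fairseq(sample_texts)
print(f"fairseq:      {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
back_translate_transformers(sample_texts)
print(f"transformers: {time.perf_counter() - start:.1f}s")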

patil-suraj commented 3 years ago

Hey @itssimon

From a quick look at your code, it seems that the fairseq model is on GPU but the transformers model is on CPU, which could explain the huge speed difference. Could you try running it on GPU?
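
A minimal sketch of that suggestion, assuming a CUDA device is available (the sample sentence is illustrative):

import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

en2de_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")
en2de_model = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-de").to(device)

# Input tensors must be moved to the same device as the model before generate()
batch = en2de_tokenizer.prepare_seq2seq_batch(["Machine learning is fun."], return_tensors="pt")
outputs = en2de_model.generate(
    batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
    num_beams=1,
    do_sample=True,
    top_p=0.7,
    num_return_sequences=2,
)
print(en2de_tokenizer.batch_decode(outputs, skip_special_tokens=True))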

itssimon commented 3 years ago

Oh dear, how embarrassing. That's it! Thanks!