facebookresearch / seamless_communication

Foundational Models for State-of-the-Art Speech and Text Translation

How to keep original word during translation #385

Open · csrednicki opened this issue 6 months ago

csrednicki commented 6 months ago

I have a sentence like This cat belongs to the <breed> domestic cat breed. Currently, during translation, the word <breed> is removed.

Current translation: Ten kot należy do rasy kotów domowych. My expectation: Ten kot należy do rasy <breed> kotów domowych.

Is there some special tag/token that I can use to preserve the original data in the translated output?

This is my code:

import torch
from seamless_communication.inference import Translator

# Prefer GPU with float16; fall back to CPU with float32.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

# Text-to-text translation needs no vocoder, so none is loaded.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)

def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    # predict returns (texts, speech); the speech output is None for T2TT.
    out_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language,
        tgt_lang=target_language,
    )
    return str(out_texts[0])

text = "This cat belongs to the <breed> domestic cat breed."

translation = run_t2tt(text, "eng", "pol")

print(translation)
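# Ten kot należy do rasy kotów domowych.  (the <breed> placeholder is dropped)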
avidale commented 6 months ago

In the official Seamless implementation (based on fairseq2), there is currently no recommended way to force the model to include a particular word in the output.

In the Hugging Face transformers implementation, though, you could use positively constrained beam search (https://huggingface.co/blog/constrained-beam-search) to enforce that.

import torch
from transformers import SeamlessM4TForTextToText, SeamlessM4TTokenizer

model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium").to('cuda')
tokenizer = SeamlessM4TTokenizer.from_pretrained(
    "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="pol"
)

# Default translation does not produce the special word (because the model was never trained this way!)
text = "This cat belongs to the <breed> domestic cat breed."
inputs = tokenizer(text, return_tensors='pt').to(model.device)
output_tokens = model.generate(**inputs, num_beams=5, tgt_lang="pol")
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))
# Ten kot należy do rasy kotów domowych.

# However, we can still force this word
force_words_ids = tokenizer(['<breed>'], add_special_tokens=False).input_ids
print(force_words_ids)
# [[45, 9653, 76, 248123]]
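# The extra nesting below ([force_words_ids]) forms one disjunctive constraint
# group containing a single alternative word.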
output_tokens = model.generate(**inputs, num_beams=5, tgt_lang="pol", force_words_ids=[force_words_ids])
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))
# Ten kot należy do rasy kotów domowych <breed>.

I don't speak Polish, so I am not sure whether putting <breed> at the end is grammatical. But a phrase like Ten kot należy do rasy kotów domowych Maine Coon. intuitively seems reasonable.

Please note that force_words_ids may work inadequately if num_beams is smaller than or equal to the number of tokens in the forced word.
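As a rough guard against that failure mode, one can size num_beams from the longest forced word. A minimal sketch, reusing the model, tokenizer, and inputs defined in the example above:

# Minimal sketch: pick a beam count strictly larger than the longest forced
# word, reusing model, tokenizer, and inputs from the example above.
force_words_ids = tokenizer(['<breed>'], add_special_tokens=False).input_ids
longest_forced = max(len(word_ids) for word_ids in force_words_ids)  # 4 tokens here
output_tokens = model.generate(
    **inputs,
    num_beams=max(5, longest_forced + 1),
    tgt_lang="pol",
    force_words_ids=[force_words_ids],
)
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))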