facebookresearch / seamless_communication

Foundational Models for State-of-the-Art Speech and Text Translation

Poor output for Japanese S2TT tasks, and `mps` device #303

Open bcherny opened 10 months ago

bcherny commented 10 months ago

I have tried the following S2TT tasks:

| # | Device | Input language | Output language | Result |
|---|--------|----------------|-----------------|--------|
| 1 | cpu | Russian | English | ✅ Correct output |
| 2 | mps | Russian | English | ❌ "I don't know what you're talking about" |
| 3 | cpu | Japanese | English | ❌ "I'm going to tell you a little bit about it, and I'm going to tell you a little bit about it." |
| 4 | mps | Japanese | English | ❌ "I don't know what I'm going to do." |

The output is wrong in three of the four cases, and when it is wrong it appears fully hallucinated, with no relation to the input audio. I am new to PyTorch and this library -- is there a set of debugging steps I should follow to figure out what's causing the low-quality results?
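One isolation step I can run myself (a minimal sketch, assuming the same `Translator` API and model names as in the "My code" snippet below) is to push the same clip through each device/dtype combination and compare the outputs side by side:

```python
import torch

from seamless_communication.inference import Translator

# Run the same clip across device/dtype combinations to see which
# configurations hallucinate. "rus.wav" is the clip from the report below.
configs = [
    ("cpu", torch.float32),
    ("cpu", torch.float16),
    ("mps", torch.float32),
]

for device_name, dtype in configs:
    translator = Translator(
        "seamlessM4T_v2_large",
        "vocoder_v2",
        device=torch.device(device_name),
        dtype=dtype,
    )
    text_output, _ = translator.predict(
        input="rus.wav",
        task_str="s2tt",
        src_lang="rus",
        tgt_lang="eng",
    )
    print(f"{device_name}/{dtype}: {text_output[0]}")
```

If only the mps rows diverge from the cpu/float32 baseline, that would point at the MPS backend rather than the model weights or the audio preprocessing.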

My machine

M2 MacBook Air

Inputs tried

My code

```python
import torch

from seamless_communication.inference import Translator

translator = Translator(
    "seamlessM4T_v2_large",
    "vocoder_v2",
    device=torch.device("mps"),  # works if I use device="cpu" + dtype=torch.float16
    dtype=torch.float32,
)

text_output, _ = translator.predict(
    input="rus.wav",
    src_lang="rus",  # or "jpn"
    task_str="s2tt",
    tgt_lang="eng",
)

print(f"Translated text to English: {text_output[0]}")
```

Output

```
Using the cached checkpoint of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached checkpoint of vocoder_v2. Set `force` to `True` to download again.
/opt/homebrew/lib/python3.11/site-packages/fairseq2/generation/beam_search.py:259: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)
  max_source_len = int(source_padding_mask.seq_lens.max())
Translated text to English: I don't know what you're talking about.
```
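The UserWarning about int64 min/max ops suggests at least one op is being silently downcast on the MPS backend. One thing that may be worth trying (an assumption on my part, not a confirmed fix) is PyTorch's CPU fallback for unsupported MPS ops, which has to be enabled via an environment variable before `torch` is imported:

```python
import os

# Assumption: PYTORCH_ENABLE_MPS_FALLBACK=1 tells PyTorch to run ops the MPS
# backend does not support on the CPU instead. It must be set before the
# first `import torch`, or it has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after setting the fallback flag)
```

Even if that helps with case 2, it wouldn't explain the wrong Japanese output on cpu (case 3), which looks like a separate issue from the mps one.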
Niraj-Kamdar commented 9 months ago

I had the same issue! The output isn't correct with MPS and float32, and yes, it works with CPU and float16.