facebookresearch / fairseq2

FAIR Sequence Modeling Toolkit 2
https://facebookresearch.github.io/fairseq2/
MIT License

Calling `SequenceToTextGeneratorBase._do_generate` with empty data throws IndexError #55

Closed znnahiyan closed 9 months ago

znnahiyan commented 1 year ago

Describe the bug: When the base class `SequenceToTextGeneratorBase` processes empty data in its method `_do_generate(self, source_seqs: Tensor, source_seq_lens: Optional[Tensor])`, the list comprehension at line 80 raises an IndexError:

https://github.com/facebookresearch/fairseq2/blob/0fc5de723979f4ff5f793a7c41292327e21141df/src/fairseq2/generation/text.py#L67-L84

Which throws:

/usr/local/lib/python3.10/dist-packages/fairseq2/generation/text.py in <listcomp>(.0)
     78 
     79         # TODO: use parallel_invoke
---> 80         sentences = [self.token_decoder(b[0].seq)[0] for b in gen_output.results]
     81 
     82         return SequenceToTextOutput(

IndexError: list index out of range
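The failure mode can be reproduced in isolation: `gen_output.results` holds one list of hypotheses per input sequence, and for degenerate input the search can leave such a list empty, so `b[0]` fails. A minimal sketch in plain Python, using stand-in data rather than the real fairseq2 types:

```python
# Stand-in for gen_output.results: one hypothesis list per input sequence.
# The second sequence produced no hypotheses, as can happen for empty input.
results = [["hypothesis A"], []]

try:
    sentences = [b[0] for b in results]  # mirrors `b[0].seq` at text.py line 80
except IndexError as exc:
    print(f"IndexError: {exc}")  # → IndexError: list index out of range
```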

Describe how to reproduce:

import torch
import torchaudio
from seamless_communication.models.inference import Translator

translator = Translator(
    model_name_or_card="seamlessM4T_medium",
    vocoder_name_or_card="vocoder_36langs",
    device=torch.device('cuda:0')
)

# Generate a silent audio clip.
sample_rate = 16000
duration = 10
channels = 1
empty_waveform = torch.zeros((channels, duration*sample_rate))
torchaudio.save('empty_audio.wav', src=empty_waveform, sample_rate=sample_rate, format='wav')

# Perform speech-to-text translation.
text, _, _ = translator.predict(
    input="empty_audio.wav",
    task_str="S2TT",
    tgt_lang="ben",
    sample_rate=sample_rate
)
Full stack trace

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input> in <cell line: 17>()
     15 torchaudio.save('empty_audio.wav', src=empty_waveform, sample_rate=sample_rate, format='wav')
     16
---> 17 text, _, _ = translator.predict(
     18     input="empty_audio.wav",
     19     task_str="S2TT",

8 frames

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/seamless_communication/models/inference/translator.py in predict(self, input, task_str, tgt_lang, src_lang, spkr, ngram_filtering, sample_rate, text_max_len_a, text_max_len_b, unit_max_len_a, unit_max_len_b)
    223             src = self.collate(self.token_encoder(text))
    224
--> 225         result = self.get_prediction(
    226             self.model,
    227             self.text_tokenizer,

/usr/local/lib/python3.10/dist-packages/seamless_communication/models/inference/translator.py in get_prediction(cls, model, text_tokenizer, unit_tokenizer, src, input_modality, output_modality, tgt_lang, ngram_filtering, text_max_len_a, text_max_len_b, unit_max_len_a, unit_max_len_b)
    139             unit_opts=unit_opts,
    140         )
--> 141         return generator(
    142             src["seqs"],
    143             src["seq_lens"],

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/seamless_communication/models/unity/generator.py in __call__(self, source_seqs, source_seq_lens, input_modality, output_modality, ngram_filtering)
    171
    172         if input_modality == "speech":
--> 173             text_output = self.s2t_generator.generate_ex(source_seqs, source_seq_lens)
    174         elif input_modality == "text" and self.t2t_generator is not None:
    175             text_output = self.t2t_generator.generate_ex(source_seqs, source_seq_lens)

/usr/local/lib/python3.10/dist-packages/fairseq2/generation/text.py in generate_ex(self, source_seqs, source_seq_lens)
    153         :math:`N` is the batch size.
    154         """
--> 155         return self._do_generate(source_seqs, source_seq_lens)
    156
    157

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/fairseq2/generation/text.py in _do_generate(self, source_seqs, source_seq_lens)
     78
     79         # TODO: use parallel_invoke
---> 80         sentences = [self.token_decoder(b[0].seq)[0] for b in gen_output.results]
     81
     82         return SequenceToTextOutput(

/usr/local/lib/python3.10/dist-packages/fairseq2/generation/text.py in <listcomp>(.0)
     78
     79         # TODO: use parallel_invoke
---> 80         sentences = [self.token_decoder(b[0].seq)[0] for b in gen_output.results]
     81
     82         return SequenceToTextOutput(

IndexError: list index out of range

Describe the expected behavior: `SequenceToTextGeneratorBase._do_generate` shouldn't throw an IndexError on empty input. Instead, it should return normally with an empty result.
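One way such input could be handled is sketched below with a hypothetical `decode_sentences` helper; this is not the actual fix that landed in fairseq2, and `token_decoder` is modeled as a plain callable returning a list of strings (the `.seq` attribute of the real hypothesis objects is omitted):

```python
# Defensive sketch: emit an empty string for sequences that produced no
# hypotheses instead of indexing into an empty list, mirroring the shape of
# `self.token_decoder(b[0].seq)[0]` at text.py line 80.
def decode_sentences(results, token_decoder):
    sentences = []
    for hypotheses in results:
        if not hypotheses:  # no output for this input sequence
            sentences.append("")
            continue
        sentences.append(token_decoder(hypotheses[0])[0])
    return sentences

print(decode_sentences([], lambda s: [s.upper()]))             # → []
print(decode_sentences([["abc"], []], lambda s: [s.upper()]))  # → ['ABC', '']
```

With this shape, an entirely empty batch falls through to an empty sentence list rather than an exception, matching the expected behavior described above.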

Environment:

Dependencies:

Additional Context: I'm not familiar with the library, so there is most likely a much simpler way to reproduce this than the example I've given above.

cbalioglu commented 1 year ago

Thanks for the bug report @znnahiyan! Let me check it out this week.

cbalioglu commented 9 months ago

Fixed in the latest SequenceGenerator API implementation. Thanks again for reporting it!