pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

torchaudio.functional.rnnt_loss input_length mismatch #2834

Open mohame54 opened 1 year ago

mohame54 commented 1 year ago

🐛 Describe the bug

```python
import torchaudio


def loss(self, audio_feat, feat_lens, target, target_lens):
    """
    audio_feat: mel spectrogram
    feat_lens: mel spectrogram lengths before padding
    target: target sequence
    target_lens: target sequence lengths before padding
    """
    x, lens = self.encoder(audio_feat, feat_lens)
    y = self.decoder(target)
    joint_out = self.joint(x, y)
    # rnnt_loss expects int32 targets and lengths.
    loss = torchaudio.functional.rnnt_loss(
        logits=joint_out, targets=target.int(),
        logit_lengths=lens.int(), target_lengths=target_lens.int(),
        blank=self.null_id)
    return loss
```

Versions

I keep getting `input length mismatch` and `output length mismatch` errors.
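As a debugging aid, here is a minimal sketch of a pre-flight check that could run right before the loss call. It follows the shape contract in the `torchaudio.functional.rnnt_loss` documentation (logits of shape `(batch, max seq length, max target length + 1, class)`, targets of shape `(batch, max target length)`), and two of its messages reuse the wording of the errors reported above. The helper name `rnnt_shape_problems` is hypothetical, and it takes plain tuples and lists (e.g. `tuple(joint_out.shape)`, `lens.tolist()`) so it runs without torch; torchaudio's own checks may differ in detail between versions.

```python
def rnnt_shape_problems(logits_shape, targets_shape, logit_lengths, target_lengths):
    """Hypothetical helper: list shape problems for an RNN-T loss call.

    Follows the documented contract of torchaudio.functional.rnnt_loss:
    logits (B, T, U + 1, V) and targets (B, U), where T = max(logit_lengths)
    and U = max(target_lengths). Takes plain tuples/lists, not tensors.
    """
    problems = []
    batch, max_t, max_u_plus_1, _num_classes = logits_shape
    max_logit_len = max(logit_lengths)
    max_target_len = max(target_lengths)
    if targets_shape[0] != batch:
        problems.append(f"batch mismatch: logits {batch} vs targets {targets_shape[0]}")
    if max_t != max_logit_len:
        # Time axis of the joint output vs. lengths reported by the encoder.
        problems.append(
            f"input length mismatch: logits.size(1)={max_t}, "
            f"max(logit_lengths)={max_logit_len}"
        )
    if max_u_plus_1 != max_target_len + 1:
        # Label axis must be U + 1 (prediction-network states for 0..U labels).
        problems.append(
            f"output length mismatch: logits.size(2)={max_u_plus_1}, "
            f"max(target_lengths) + 1={max_target_len + 1}"
        )
    if targets_shape[1] != max_target_len:
        problems.append(
            f"targets width {targets_shape[1]} != max(target_lengths)={max_target_len}"
        )
    return problems


# Encoder subsampled 100 -> 25 frames, but the pre-subsampling lengths were passed on.
print(rnnt_shape_problems((2, 25, 11, 32), (2, 10), [100, 80], [10, 7]))
```

In the example, the joint output has 25 frames while the reported lengths still describe the 100-frame input, which trips the time-axis check.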

hwangjeff commented 1 year ago

Can you provide the stack trace and reproduction steps?

mohame54 commented 1 year ago

I want to train a transducer model for a speech recognition task. First, I extract the mel spectrogram from the audio signal and keep track of each spectrogram's length for training. Then I encode the labels with Hugging Face tokenizers, also keeping track of the label lengths. The label encoding works like this: first I tokenize the label, then I prepend the null token, and finally I pad the labels.
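The label pipeline described above can be sketched as follows. This is a minimal sketch under an assumption that is common for transducers but not confirmed in this thread: the prepended null (blank) token belongs only to the prediction network's input, while the `targets` and `target_lengths` passed to `rnnt_loss` cover the real tokens alone. The function name `prepare_labels` is hypothetical; the zero padding value follows the torchaudio docs, which describe `targets` as zero padded.

```python
def prepare_labels(token_ids, blank_id, pad_id=0):
    """Hypothetical helper: build decoder inputs and loss targets.

    token_ids: list of tokenized labels (lists of ints), without blank.
    Returns (decoder_inputs, targets, target_lengths) where
      targets        -> B x U, padded with pad_id, no leading blank
      decoder_inputs -> B x (U + 1), blank prepended to each padded row
      target_lengths -> real token counts, excluding the prepended blank
    """
    target_lengths = [len(seq) for seq in token_ids]
    max_u = max(target_lengths)
    targets = [list(seq) + [pad_id] * (max_u - len(seq)) for seq in token_ids]
    decoder_inputs = [[blank_id] + row for row in targets]
    return decoder_inputs, targets, target_lengths


dec_in, targets, target_lengths = prepare_labels([[5, 6, 7], [8]], blank_id=0)
print(dec_in)
print(targets)
print(target_lengths)
```

With this split, a decoder fed `dec_in` produces `U + 1` steps, so the joint output has the `max(target_lengths) + 1` label axis that `rnnt_loss` expects. Passing the blank-prefixed rows as `targets` (with lengths that count the blank) would instead leave that axis one step short.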

mohame54 commented 1 year ago

> Can you provide the stack trace and reproduction steps?

I could share my Colab notebook if you want.

mohame54 commented 1 year ago

I'd like to note that my encoder is a Conformer network that applies convolutional subsampling before the attention layers, so it shortens the time axis (the audio feature length) during training to avoid running out of memory.
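Because the convolutional front end shortens the time axis, the lengths handed to `rnnt_loss` as `logit_lengths` have to be recomputed with the same arithmetic as the subsampling layers. A minimal sketch using the output-length formula from the PyTorch `Conv1d`/`Conv2d` documentation; the layer configuration (two convolutions with kernel 3, stride 2, no padding, roughly 4x reduction) is a hypothetical stand-in for the actual encoder's, and `subsampled_lengths` is an illustrative name.

```python
def conv_out_length(length, kernel_size, stride, padding=0, dilation=1):
    # Output-length formula from the PyTorch Conv1d/Conv2d docs.
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def subsampled_lengths(feat_lens, layers=((3, 2, 0), (3, 2, 0))):
    """Hypothetical config: each layer is (kernel_size, stride, padding)."""
    out = []
    for n in feat_lens:
        # Apply every subsampling layer in order to this utterance's length.
        for kernel_size, stride, padding in layers:
            n = conv_out_length(n, kernel_size, stride, padding)
        out.append(n)
    return out


print(subsampled_lengths([100, 80]))
```

Because the formula never decreases as the input grows, the longest utterance's result equals `x.size(1)` of the padded batch, which is the value compared against `max(logit_lengths)`. If the encoder already returns `lens`, comparing them with this calculation is a quick way to spot an off-by-one.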