SandyPanda-MLDL opened 5 months ago
I have the same issue, but for this code snippet:
```python
d_loss = self._dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
```
More specifically, the error occurs in the forward method of the WavLMLoss class:
```python
def forward(self, wav, y_rec):
    with torch.no_grad():
        wav_16 = self.resample(wav)
        wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)
        y_rec_embeddings = y_rec_embeddings.hidden_states
    floss = 0
    for er, eg in zip(wav_embeddings, y_rec_embeddings):
        floss += torch.mean(torch.abs(er - eg))
    return floss.mean()
```
The line

```python
self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)
```

is giving me the exact same error, and I don't know why.
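For context, the loss computed by that forward method is just an L1 distance between WavLM hidden states from the real and reconstructed audio. A minimal sketch with dummy tensors (the shapes and layer count are assumptions, not the real WavLM output):

```python
import torch

# pretend hidden states from 3 transformer layers: (batch, frames, dim)
wav_embeddings = [torch.randn(2, 50, 768) for _ in range(3)]
# reconstructed-audio embeddings, offset by 0.1 so the loss is predictable
y_rec_embeddings = [t + 0.1 for t in wav_embeddings]

floss = 0
for er, eg in zip(wav_embeddings, y_rec_embeddings):
    floss += torch.mean(torch.abs(er - eg))  # per-layer L1 feature distance
# each layer contributes ~0.1, so floss is ~0.3 here
```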
What are your dependency versions for this project?
Found the solution: you need to call self.wavlm.eval() at the start of the forward method of the WavLMLoss class in the losses module. Worked for me.
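Calling .eval() switches modules such as Dropout and BatchNorm to inference behavior, which is what you want when the WavLM model is used only as a frozen feature extractor. A minimal sketch of the train/eval difference (illustrative only, not the StyleTTS2 code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()          # training mode: elements are randomly zeroed and rescaled
out_train = drop(x)   # values are 0.0 or 2.0, never 1.0

drop.eval()           # eval mode: Dropout becomes the identity
out_eval = drop(x)    # exactly equal to x
```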
I am getting the mentioned error in this part of the code:

```python
if epoch >= TMA_epoch:  # start TMA training
    loss_s2s = 0
    for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
        loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
    loss_s2s /= texts.size(0)
```
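That loop averages a per-utterance cross-entropy, slicing each prediction and target to the valid (unpadded) length of the sequence. A minimal self-contained sketch with made-up shapes (batch of 2, max length 5, 10 symbol classes are assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s2s_pred = torch.randn(2, 5, 10)       # (batch, max_len, n_symbols) logits
texts = torch.randint(0, 10, (2, 5))   # target symbol indices, padded to max_len
input_lengths = torch.tensor([5, 3])   # true length of each utterance

loss_s2s = 0
for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
    # only the first _text_length frames are real; padding is excluded
    loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
loss_s2s /= texts.size(0)
```

If the error appears here, check that texts contains integer class indices (dtype long) within range of the logit dimension, and that input_lengths never exceeds the padded length.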