yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
MIT License

In training Stage 1, after the 49th epoch, getting RuntimeError: you can only change requires_grad flags of leaf variables, at g_loss.requires_grad = True #258

Open SandyPanda-MLDL opened 5 months ago

SandyPanda-MLDL commented 5 months ago

I am getting the mentioned error in this part of the code:

        if epoch >= TMA_epoch:  # start TMA training
            loss_s2s = 0
            for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
                loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
            loss_s2s /= texts.size(0)

            loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10

            loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
            print(f'the shape of both wav and y_rec respectively {wav.shape} and {y_rec.shape}')
            loss_slm = wl(wav.detach(),  y_rec.squeeze(1)).mean()

            g_loss = loss_params.lambda_mel * loss_mel + \
                loss_params.lambda_mono * loss_mono + \
                loss_params.lambda_s2s * loss_s2s + \
                loss_params.lambda_gen * loss_gen_all + \
                loss_params.lambda_slm * loss_slm
            print(f'Generator loss is {g_loss}')

        running_loss += accelerator.gather(loss_mel).mean().item()
        #print(f"g-loss is {type(g_loss)}")
        optimizer.zero_grad()
        g_loss.requires_grad = True  # <-- this line raises the RuntimeError
        g_loss.backward()
        #accelerator.backward(g_loss)
        optimizer.step()
        # g_loss.requires_grad = True
        # g_loss.backward()
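For context (not from the thread): PyTorch only allows toggling `requires_grad` on leaf tensors, i.e. tensors with no gradient history. Since `g_loss` is produced by arithmetic on other tensors, it is a non-leaf, and the assignment fails. A minimal sketch reproducing the error:

```python
import torch

# A leaf tensor (created directly) accepts requires_grad changes.
leaf = torch.zeros(3)
leaf.requires_grad = True  # OK

# A non-leaf tensor (produced by an operation on a tensor that requires
# grad) does not -- this reproduces the error from the traceback above.
x = torch.ones(3, requires_grad=True)
non_leaf = (x * 2).sum()
try:
    non_leaf.requires_grad = True
except RuntimeError as e:
    print(e)  # -> the message quoted in the issue title
```

Note that if `g_loss` already carries a computation graph, the assignment is unnecessary (just call `backward()`), and if it does not (e.g. everything feeding it was detached), setting the flag would not restore gradients anyway, so the fix has to be upstream of the loss.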
Dforgeek commented 3 months ago

I have the same issue, but for this code snippet:

    d_loss = self._dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()

Getting into the details, the error occurs in the forward method of the WavLMLoss class:

    def forward(self, wav, y_rec):

        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)
        y_rec_embeddings = y_rec_embeddings.hidden_states

        floss = 0
        for er, eg in zip(wav_embeddings, y_rec_embeddings):
            floss += torch.mean(torch.abs(er - eg))

        return floss.mean()

self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True) is giving me the exact same error and I don't know why.

What are your dependency versions for this project?

Dforgeek commented 2 months ago

Found the solution: you need to call self.wavlm.eval() at the start of the forward method of the WavLMLoss class in the losses module. Worked for me.
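To illustrate where that call goes, here is a minimal sketch with a stand-in frozen module in place of the real HuggingFace WavLM (the class name and layers below are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn

class WavLMLossSketch(nn.Module):
    """Stand-in for WavLMLoss: a frozen feature extractor compared on
    real vs. reconstructed audio."""

    def __init__(self):
        super().__init__()
        # Stand-in "wavlm": any module with train-mode behaviour (dropout).
        self.wavlm = nn.Sequential(nn.Linear(16, 8), nn.Dropout(0.5))
        for p in self.wavlm.parameters():
            p.requires_grad = False  # frozen feature extractor

    def forward(self, wav, y_rec):
        self.wavlm.eval()  # the suggested fix: force eval mode on every call
        with torch.no_grad():
            ref = self.wavlm(wav)   # reference path: no graph needed
        gen = self.wavlm(y_rec)     # generated path keeps its graph
        return torch.mean(torch.abs(ref - gen))
```

Gradients still flow into y_rec through the frozen extractor, so the loss remains differentiable; .eval() only disables train-mode behaviour such as dropout (and, in HuggingFace models, gradient checkpointing, which is a plausible source of the leaf-variable error when the model is left in training mode).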