bfs18 / rfwave

MIT License
78 stars 4 forks source link

About the STFT Loss #4

Open sh-lee-prml opened 1 month ago

sh-lee-prml commented 1 month ago

Hi, thanks for the nice work!

I have a question about the STFT Loss.

Previously, I tried directly applying the STFT loss to the estimated vector field, and this decreased the performance.

However, I found that you apply the STFT loss to (the estimated vector field + X0), and this part is very interesting to me.

My question is:

Have you compared applying the STFT loss to the estimated vector field directly?

If you did, please share your experience!

Thanks again for the nice work!

bfs18 commented 1 month ago

Hi @sh-lee-prml, thank you for your interest in this matter.

In the Rectified Flow formulation, adding the estimated vector field to X0 yields an estimate of X1, so it is indeed sensible to apply the STFT loss to this sum. Directly applying the STFT loss to the vector field alone is less sound. That said, since |STFT(x)| is a fixed function, the STFT loss on the vector field can be viewed as matching a transformed feature, so I guess it does not necessarily degrade the results when weighted appropriately.
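In code, the idea of applying the spectral loss to X0 + vector field (the implied estimate of X1) rather than to the vector field alone can be sketched as follows. This is a simplified single-resolution magnitude-spectrogram loss, not RFWave's exact implementation, and `v_pred` is a stand-in for the model output:

```python
import torch

def stft_mag_loss(x_hat, x, n_fft=1024, hop=256):
    # L1 distance between magnitude spectrograms of estimate and target.
    win = torch.hann_window(n_fft)
    s_hat = torch.stft(x_hat, n_fft, hop, window=win, return_complex=True).abs()
    s = torch.stft(x, n_fft, hop, window=win, return_complex=True).abs()
    return (s_hat - s).abs().mean()

# Rectified Flow pairing: x0 is noise, x1 is data, target field is x1 - x0.
x0 = torch.randn(2, 8192)                        # noise sample
x1 = torch.randn(2, 8192)                        # data (waveform) sample
v_pred = x1 - x0 + 0.01 * torch.randn(2, 8192)   # stand-in for the model's vector field

# Apply the STFT loss to x0 + v_pred, i.e. to the implied estimate of x1,
# rather than to v_pred directly.
x1_hat = x0 + v_pred
loss_stft = stft_mag_loss(x1_hat, x1)
```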

To find a proper weight for the STFT loss, one effective approach is to examine the gradient norms produced by different loss terms. In my practice, I adjust the weight for the STFT loss so that the gradient norm of the STFT loss (g_stft) is approximately one-tenth that of the Rectified Flow loss (g_rf). The following code can be used to compute these gradient norms:

grads_stft = torch.autograd.grad(loss_stft, model.parameters(), retain_graph=True)
g_stft = torch.norm(torch.stack([torch.norm(g) for g in grads_stft if g is not None]))
grads_rf = torch.autograd.grad(loss_rf, model.parameters(), retain_graph=True)
g_rf = torch.norm(torch.stack([torch.norm(g) for g in grads_rf if g is not None]))

This ensures that the STFT loss contributes to the overall learning process without overwhelming the primary loss function.
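Putting the one-tenth rule together, a minimal sketch of the calibration might look like the following. The model and both losses here are toy stand-ins, purely to make the snippet self-contained:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)                      # stand-in model
x = torch.randn(4, 16)
loss_rf = (model(x) - x).pow(2).mean()         # stand-in Rectified Flow loss
loss_stft = model(x).abs().mean()              # stand-in STFT loss

def grad_norm(loss, params):
    # Total gradient norm of `loss` w.r.t. `params`; keep the graph for reuse.
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.norm(torch.stack([torch.norm(g) for g in grads if g is not None]))

g_rf = grad_norm(loss_rf, list(model.parameters()))
g_stft = grad_norm(loss_stft, list(model.parameters()))

# Choose the weight so that weight * g_stft ≈ 0.1 * g_rf,
# i.e. the STFT gradients are about one-tenth of the RF gradients.
weight = 0.1 * (g_rf / g_stft).item()
total_loss = loss_rf + weight * loss_stft
```

In practice one would inspect these norms over a few training batches and fix a round weight, rather than recomputing it every step.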

sh-lee-prml commented 1 month ago

Thanks for the reply!

I've now tried using the STFT loss with weights of 1, 0.1, and 0.01.

Thanks for sharing your experience. I will check the gradient norms following your suggestion!

I can share my results after training the model with the STFT loss.

Thanks!