Hey, thanks for checking out auraloss!
I have seen this issue before; I traced it to a numerical instability in the computation of SI-SDR. My guess is that it is triggered when the target contains all zeros. Here is a minimal code example to demonstrate that:
```python
import torch
import auraloss

x = torch.rand(4, 2, 100)
y = torch.zeros(4, 2, 100)

sisdr_loss = auraloss.time.SISDRLoss()

print(sisdr_loss(x, y))  # nan
print(sisdr_loss(y, x))  # 80 dB
```
When the target y is a vector of zeros we get a NaN for the loss. We would expect the loss to be symmetric, so setting either x or y as the target should produce the same 80 dB error (the error is bounded at 80 dB due to the choice of eps=1e-8).
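For reference, here is a rough sketch of an eps-stabilized SI-SDR following the standard definition (an illustration only, not the actual auraloss internals). It shows where an all-zero target produces a 0/0, and why eps=1e-8 floors the SI-SDR at -80 dB, which is the 80 dB value seen above once the sign is flipped for the loss:

```python
import torch

def si_sdr(estimate, target, eps=1e-8):
    # Project the estimate onto the target. With an all-zero target this
    # denominator is 0, giving 0/0 = NaN unless eps is added.
    alpha = (estimate * target).sum(-1, keepdim=True) / (
        (target ** 2).sum(-1, keepdim=True) + eps
    )
    s_target = alpha * target
    e_noise = estimate - s_target
    # eps inside the log floors the result at 10 * log10(1e-8) = -80 dB,
    # so the loss (negative SI-SDR) is bounded at 80 dB.
    ratio = (s_target ** 2).sum(-1) / ((e_noise ** 2).sum(-1) + eps)
    # Returns SI-SDR in dB per item/channel; a loss would be its negative mean.
    return 10 * torch.log10(ratio + eps)
```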
Ideally you would not have an all-zero target when training, but the code ought to handle this situation. I have made a quick fix for this; you can try it by installing the latest auraloss version directly from the GitHub repo:
pip install git+https://github.com/csteinmetz1/auraloss.git
Let me know if this works for you. I will push a new release shortly afterwards.
Thanks for the reply - I'll be able to test this within a few days (currently waiting for a replacement GPU).
I installed the latest from GitHub like you suggested, and I'm definitely having a better time with it, @csteinmetz1.
Great! Thanks for testing it out.
Hello, I'm trying to use the SI-SDR and/or SD-SDR loss for a music source separation model.
I'm working with the well-known open-unmix model (https://github.com/sigsep/open-unmix-pytorch).
The original model is as follows:
I wanted to see if I could observe any differences by using SDR (which is the actual evaluation metric for the full source separation task) within the training loop:
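Roughly, the change looks like this (heavily simplified; the model, optimizer, and data below are stand-ins for my actual open-unmix setup, which operates on spectrograms and reconstructs waveforms with an ISTFT step):

```python
import torch
import auraloss

# Stand-ins for my actual setup: a dummy waveform-to-waveform model in place
# of open-unmix + ISTFT, and random tensors in place of the real dataloader.
model = torch.nn.Conv1d(2, 2, kernel_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = auraloss.time.SISDRLoss()  # instead of the usual MSE loss

x = torch.rand(4, 2, 44100)  # mixture waveforms
y = torch.rand(4, 2, 44100)  # target source waveforms

for step in range(10):
    optimizer.zero_grad()
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()
```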
In the first iteration, the network's prediction is so bad that in a few places the SDR value is NaN. After this, the gradients get into a bad state and the next prediction from the network is entirely NaN.
Here are some print statements I inserted into the body of the SI-SDR code to help pinpoint the issue. What's being printed is:
The output below is for the first two epochs, showing how it goes from two NaNs in the SI-SDR loss function to all NaNs.
Do you have any suggestions? I tried torch.nan_to_num (without any arguments, i.e. the default substitutions: https://pytorch.org/docs/stable/generated/torch.nan_to_num.html), and the behavior is basically identical (except that the loss is 0 instead of NaN).
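I also thought about skipping the bad items before reducing, along these lines (an untested sketch, computing the loss per batch item and averaging only the finite values):

```python
import torch
import auraloss

sisdr_loss = auraloss.time.SISDRLoss()

def finite_mean_sisdr(y_hat, y):
    # Compute the loss one batch item at a time and average only the finite
    # values, so a few bad items don't poison the gradients for the whole batch.
    losses = torch.stack(
        [sisdr_loss(y_hat[i : i + 1], y[i : i + 1]) for i in range(y.shape[0])]
    )
    finite = torch.isfinite(losses)
    if finite.any():
        return losses[finite].mean()
    # Fall back to a zero loss that still carries a graph, so backward() is a no-op.
    return 0.0 * y_hat.sum()
```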