csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

SDR loss sensitive to nan #25

Closed: sevagh closed this issue 3 years ago

sevagh commented 3 years ago

Hello, I'm trying to use SI-SDR and/or SD-SDR loss for a model for music source separation.

I'm working with the well-known open-unmix model (https://github.com/sigsep/open-unmix-pytorch).

The original model works as follows (a minimal sketch follows the list):

  1. Xmag = magnitude spectrogram of input waveform (mixed song)
  2. Ymag_hat = network prediction of magnitude spectrogram of the source being separated (one of drums, bass, vocals, other)
  3. Ymag = magnitude spectrogram of ground truth of the source being separated
  4. Loss = MSE(Ymag_hat, Ymag)
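For concreteness, here is a minimal sketch of that training objective. The toy model, shapes, and names below are illustrative placeholders, not open-unmix's actual API:

import torch

# illustrative shapes: (batch, channels, freq_bins, frames)
model = torch.nn.Conv2d(2, 2, kernel_size=1)  # stand-in for the real network
Xmag = torch.rand(16, 2, 2049, 256)           # 1. mix magnitude spectrogram
Ymag = torch.rand(16, 2, 2049, 256)           # 3. ground-truth source magnitude

Ymag_hat = model(Xmag)                               # 2. predicted source magnitude
loss = torch.nn.functional.mse_loss(Ymag_hat, Ymag)  # 4. MSE loss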

I wanted to see whether using SDR (which is the actual evaluation metric of the full source separation task) inside the training loop would make any difference (sketched after the list):

  1. X = complex spectrogram of input waveform (mixed song)
  2. Xmag = magnitude(X)
  3. Xphase = phase(X)
  4. Ymag_hat = network prediction of magnitude spectrogram of source being separated
  5. Ycomplex_hat = Ymag_hat * Xphase (combine source magnitude + mix phase for source complex spectrogram)
  6. y_hat = istft(Ycomplex_hat)
  7. Loss = auraloss.SISDR(y_hat, y), loss on SDR of waveforms
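A minimal sketch of steps 1-7, assuming mono signals for brevity (the STFT parameters and the stand-in model below are illustrative, not open-unmix's actual settings):

import torch
import auraloss

n_fft, hop = 4096, 1024
window = torch.hann_window(n_fft)
n_bins = n_fft // 2 + 1
model = torch.nn.Linear(n_bins, n_bins)  # toy stand-in for the real network
loss_fn = auraloss.time.SISDRLoss()

x = torch.rand(16, 44100)  # mix waveforms (batch, samples)
y = torch.rand(16, 44100)  # ground-truth source waveforms

X = torch.stft(x, n_fft, hop, window=window, return_complex=True)  # 1.
Xmag = X.abs()                                                     # 2.
Xphase = torch.angle(X)                                            # 3.
Ymag_hat = model(Xmag.transpose(1, 2)).transpose(1, 2)             # 4.
Ycomplex_hat = Ymag_hat * torch.exp(1j * Xphase)                   # 5. source mag + mix phase
y_hat = torch.istft(Ycomplex_hat, n_fft, hop, window=window,
                    length=x.shape[-1])                            # 6.
loss = loss_fn(y_hat.unsqueeze(1), y.unsqueeze(1))                 # 7. SI-SDR on waveforms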

On the first iteration, the network's prediction is bad enough that the SDR value is NaN in a few places. After that, the gradients get into a bad state and the network's next prediction is entirely NaN.

Here are some print statements I inserted into the body of the SI-SDR code to help pinpoint the issue. What's printed is:

  1. Input tensor (waveform)
  2. Output tensor (waveform from the neural network's predicted spectrogram)
  3. SI-SDR loss values (each intermediate step printed before the final value)

The output below covers the first two training steps, showing how the per-item SI-SDR losses go from a couple of NaN entries to all NaNs.

(umx-gpu) sevagh:umx-mr $ ./train.sh
Using GPU: True
Training Epoch:   0%| | 0/1000 [00:00<?, ?it/s]
  0%| | 0/344 [00:00<?, ?it/s]
y_hat, y shape: torch.Size([16, 2, 264600]), torch.Size([16, 2, 264600])
tensor([[[-2.1781e-03, -6.5221e-04,  1.1178e-03,  ...,  1.3619e-03,
           2.1249e-03,  1.0944e-02],
         [-1.2133e-02, -8.7155e-03, -6.3351e-03,  ...,  1.4447e-02,
           2.1405e-02,  2.3420e-02]],

        [[ 2.0096e-01,  1.8979e-01,  2.0478e-01,  ..., -1.1957e-02,
          -4.6632e-03, -1.6420e-03],
         [ 1.4364e-01,  1.3714e-01,  1.5252e-01,  ..., -2.3354e-02,
          -7.9426e-03,  4.1423e-03]],

        [[-1.7183e-02, -9.4009e-03,  9.0017e-05,  ..., -4.4771e-02,
          -4.8189e-02, -5.0478e-02],
         [-1.9377e-02, -1.3426e-02,  1.5882e-03,  ..., -2.9326e-02,
          -2.8929e-02, -2.5298e-02]],

        ...,

        [[-3.7857e-03, -4.0909e-03, -3.7857e-03,  ..., -9.3050e-05,
           3.6471e-04,  7.9196e-04],
         [-1.6300e-03, -1.5385e-03, -1.9657e-03,  ...,  1.0948e-04,
          -5.0087e-04, -1.6518e-04]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 1.0614e-01,  9.7930e-02,  8.7920e-02,  ..., -1.0565e-01,
          -4.4679e-02,  3.5308e-02],
         [ 1.0367e-01,  9.5126e-02,  8.4048e-02,  ..., -1.2213e-01,
          -6.5000e-02,  1.4834e-02]]], device='cuda:0')

tensor([[[ 0.1386,  0.1427,  0.1455,  ...,  0.1328,  0.0626,  0.0005],
         [ 0.0152,  0.0212,  0.0245,  ...,  0.1702,  0.1278,  0.0885]],

        [[ 0.3065,  0.2870,  0.2915,  ..., -0.1038, -0.0997, -0.0963],
         [ 0.1356,  0.1494,  0.1805,  ..., -0.1743, -0.1608, -0.1511]],

        [[ 0.0508,  0.0478,  0.0527,  ...,  0.2344,  0.2701,  0.2726],
         [-0.1404, -0.1422, -0.1259,  ...,  0.3193,  0.3360,  0.3355]],

        ...,

        [[-0.0528, -0.0437, -0.0349,  ..., -0.1331, -0.1366, -0.1403],
         [-0.2515, -0.2280, -0.2245,  ...,  0.0535,  0.0505,  0.0503]],

        [[-0.0694, -0.0605, -0.0499,  ..., -0.4447, -0.4446, -0.4433],
         [ 0.0713,  0.0818,  0.0828,  ..., -0.0331, -0.0137,  0.0073]],

        [[ 0.3649,  0.3720,  0.3841,  ..., -0.1016, -0.0735, -0.0239],
         [ 0.3345,  0.3408,  0.3497,  ..., -0.1535, -0.0997, -0.0270]]],
       device='cuda:0', grad_fn=<SubBackward0>)

losses: tensor([[-13.2343, -12.9311],
        [  1.0036,  -0.8949],
        [ -6.3418,  -7.2471],
        [-53.9541, -59.8263],
        [  3.6331,   1.7483],
        [-21.9817, -20.3990],
        [     nan,      nan],
        [ -6.9908,  -7.9694],
        [ -2.5666,  -4.8039],
        [-35.9928, -35.8143],
        [-18.4182, -16.2987],
        [ -4.5252,  -9.3269],
        [ -4.0607,  -5.2343],
        [ -3.2994,  -1.3853],
        [     nan,      nan],
        [ -2.2624,   0.1392]], device='cuda:0', grad_fn=<MulBackward0>)
losses: -10.913591384887695
loss: 10.913591384887695
  0%| | 1/344 [00:02<14:28,  2.53s/it]
y_hat, y shape: torch.Size([16, 2, 264600]), torch.Size([16, 2, 264600])
tensor([[[ 0.0071, -0.0640, -0.0851,  ..., -0.0198, -0.0260, -0.0203],
         [ 0.0515, -0.0161, -0.0546,  ...,  0.0040,  0.0071,  0.0099]],

        [[ 0.0042,  0.0053,  0.0014,  ...,  0.0006,  0.0010,  0.0017],
         [-0.0004, -0.0029, -0.0021,  ..., -0.0024, -0.0026, -0.0041]],

        [[-0.0011, -0.0011, -0.0011,  ...,  0.0019,  0.0021,  0.0017],
         [-0.0004, -0.0005, -0.0005,  ...,  0.0010,  0.0012,  0.0012]],

        ...,

        [[ 0.1997,  0.1637,  0.1673,  ..., -0.0564, -0.0582, -0.0581],
         [ 0.2007,  0.1652,  0.1685,  ..., -0.0525, -0.0540, -0.0547]],

        [[ 0.0412,  0.0584,  0.0503,  ..., -0.0031,  0.0207,  0.0357],
         [ 0.0418,  0.0590,  0.0510,  ...,  0.0010,  0.0250,  0.0392]],

        [[ 0.0338,  0.0072, -0.0184,  ..., -0.0112, -0.0148, -0.0131],
         [ 0.0114,  0.0253,  0.0207,  ..., -0.0008, -0.0031, -0.0061]]],
       device='cuda:0')

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<SubBackward0>)

losses: tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]], device='cuda:0', grad_fn=<MulBackward0>)
losses: nan
loss: nan

Do you have any suggestions? I tried torch.nan_to_num (without any arguments, i.e. the default substitutions: https://pytorch.org/docs/stable/generated/torch.nan_to_num.html), and the behavior is basically identical, except the loss is 0 instead of NaN.
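For what it's worth, here is a minimal sketch (independent of auraloss) of why sanitizing only the forward value may not help: once a NaN appears inside the autograd graph, the chain rule multiplies it into the upstream gradients (0 * inf and 0 / 0 are both NaN), so the weights still get poisoned:

import torch

x = torch.zeros(3, requires_grad=True)
y = torch.log(x)                  # -inf at the zeros
z = y - y                         # -inf - (-inf) = nan
loss = torch.nan_to_num(z).sum()  # forward value is a clean 0.0

loss.backward()
print(loss)    # tensor(0.)
print(x.grad)  # tensor([nan, nan, nan]): the NaN survives in the backward pass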

csteinmetz1 commented 3 years ago

Hey, thanks for checking out auraloss!

I have seen this issue before and traced it to a numerical instability in the computation of SI-SDR. My guess is that it is triggered when the target contains all zeros. Here is a minimal code example to demonstrate:

import torch
import auraloss

x = torch.rand(4, 2, 100)   # random signal, used as the estimate or target below
y = torch.zeros(4, 2, 100)  # all-zero signal

sisdr_loss = auraloss.time.SISDRLoss()

print(sisdr_loss(x, y))  # nan   (all-zero target)
print(sisdr_loss(y, x))  # 80 dB (all-zero estimate)

When the target y is a vector of zeros we get a NaN for the loss. We would expect the loss to be symmetric, so using either x or y as the target should produce the same 80 dB error (the error is bounded at 80 dB due to the choice of eps=1e-8).
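For reference, a sketch of where that NaN can come from in a straightforward SI-SDR implementation (an illustration of the failure mode, not auraloss's exact code): the projection of the estimate onto the target divides by ||y||^2, which is 0/0 when the target is all zeros:

import torch

def naive_sisdr(x_hat, y):
    # project the estimate onto the target; with an all-zero target the
    # denominator is 0, so alpha = 0/0 = nan and everything downstream is nan
    alpha = (x_hat * y).sum(-1, keepdim=True) / (y * y).sum(-1, keepdim=True)
    s_target = alpha * y
    e_noise = x_hat - s_target
    return 10 * torch.log10((s_target ** 2).sum(-1) / (e_noise ** 2).sum(-1))

print(naive_sisdr(torch.rand(2, 100), torch.zeros(2, 100)))  # tensor([nan, nan])

Adding a small eps inside these denominators is the usual way to keep alpha and the final ratio finite.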

Ideally you would not have an all-zero target during training, but the code ought to handle this situation anyway. I have made a quick fix for it, which you can try by installing the latest auraloss directly from the GitHub repo:

pip install git+https://github.com/csteinmetz1/auraloss.git

Let me know if this works for you. I will push a new release shortly afterwards.

sevagh commented 3 years ago

Thanks for the reply. I'll be able to test this within a few days (currently waiting for a replacement GPU).

sevagh commented 3 years ago

I'm definitely having a better time with it, @csteinmetz1: [screenshot of the training run attached]

I installed the latest from GitHub as you suggested.

csteinmetz1 commented 3 years ago

Great! Thanks for testing it out.