csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

[cm] Fix DCLoss and refactor ESRLoss #54

Closed christhetree closed 1 year ago

christhetree commented 1 year ago

This PR refactors DCLoss and ESRLoss to include an eps term that prevents divide-by-zero errors on silent signals. It also adds type hinting, removes unnecessary abs() calls, and breaks the logic down into explicit numerator and denominator terms.
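
For illustration, here is a minimal sketch of the refactored shape of ESRLoss. The eps default and the final mean reduction are illustrative assumptions, not a quote of the PR code:

import torch


class ESRLoss(torch.nn.Module):
    """Error-to-signal ratio loss with an eps term guarding against silent targets."""

    def __init__(self, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # numerator: energy of the error, per batch item (no abs() needed after squaring)
        num = ((target - input) ** 2).sum(dim=-1)
        # denominator: energy of the target, with eps preventing divide-by-zero on silence
        denom = (target ** 2).sum(dim=-1) + self.eps
        # reduce across the batch in the final step
        return (num / denom).mean()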

Finally, in the current DCLoss the denominator term is squared after the mean is taken, but the mathematical notation in the paper implies the opposite order. The paper author's source code agrees with that reading (their denominator is the mean energy of the target), but, more confusingly, their implementation of the numerator is quite different (tr here appears to be an alias for torch):

loss = tr.pow(tr.add(tr.mean(target, 0), -tr.mean(output, 0)), 2)
loss = tr.mean(loss)
energy = tr.mean(tr.pow(target, 2)) + self.epsilon
loss = tr.div(loss, energy)
return loss

I believe this PR should exhibit the correct behavior now, but I'm curious what your thoughts are.

csteinmetz1 commented 1 year ago

I think the code block you shared makes sense even though it is written differently from the equation in the original paper. The loss is measuring the MSE between the DC components (means) of the target and the output. Based on my reading we should take the mean before squaring in the numerator but not in the denominator.

The other difference is whether you normalize by the energy of each batch item or by the energy of the entire batch. In the code you shared it seems like energy is the aggregate energy across the whole batch. Is that right? I think the original idea would be to compute the energy of each batch item. However, the paper makes no mention of the batch dimension, so perhaps this code assumes a batch size of 1?

What do you think about the following for the DCLoss?

num = ((target - input).mean(dim=-1) ** 2)
denom = (target ** 2).mean(dim=-1) + self.eps

[attached screenshot: Screen Shot 2023-03-10 at 09 52 20]
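
For concreteness, a minimal sketch of how those two lines could sit inside a full module; the eps default and the final mean reduction are assumptions, not a quote of the eventual PR code:

import torch


class DCLoss(torch.nn.Module):
    """DC offset loss: squared difference of per-item means over per-item target energy."""

    def __init__(self, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # mean is taken before squaring in the numerator, per batch item...
        num = (target - input).mean(dim=-1) ** 2
        # ...but not in the denominator, which is the mean energy of the target
        denom = (target ** 2).mean(dim=-1) + self.eps
        # reduce across the batch in the final step
        return (num / denom).mean()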

christhetree commented 1 year ago

So on further examination, the reason the code block didn't make sense at first is that it was written for the original PyTorch LSTM convention, where the first dimension is the sequence length, not the batch size. Looked at from that point of view, I believe the code ends up being the same, which should also answer your question about the batch size. As is common with loss functions, I think everything should be computed per batch item and then reduced in the final step.
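
For example, a small shape check (the sizes here are arbitrary) showing why the two layouts give the same per-item means:

import torch

# the author's code assumes a (seq_len, batch) layout, as in the original PyTorch LSTM API
x_seq_first = torch.randn(16000, 4)            # (seq_len, batch)
x_batch_first = x_seq_first.transpose(0, 1)    # (batch, seq_len)

# mean over dim 0 in the seq-first layout equals mean over dim -1 in the batch-first layout
assert torch.allclose(x_seq_first.mean(dim=0), x_batch_first.mean(dim=-1))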

"Based on my reading we should take the mean before squaring in the numerator but not the denominator." <- I agree with this. The code you proposed matches the code for DCLoss in this PR so I think we're on the same page there.

christhetree commented 1 year ago

My bad, I realized I made a mistake in the numerator. I'll push a fix up now.

csteinmetz1 commented 1 year ago

Looks good. Thanks.