csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

Negative Losses for SNR and SISNR #13

Closed LearnedVector closed 3 years ago

LearnedVector commented 3 years ago

Hello, great library! I was wondering whether it's normal for the SNR and SISNR losses to be negative. I'm exploring multiple losses for my network, and both of these losses are outputting negative values. Thanks!

csteinmetz1 commented 3 years ago

Yes, it is normal for these losses to output negative values. For the SNR-based metrics a higher value is better, so to use them as loss functions we minimize the negative of the metric, which is the same as maximizing the metric itself. In auraloss, this negation is applied directly inside the loss function, so you can use the loss output for optimization as-is, without flipping its sign yourself. I do understand this can be somewhat confusing, so we have noted it in the docstrings of the relevant functions.
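To make the sign convention concrete, here is a minimal NumPy sketch of the common textbook definitions of SNR and SI-SNR (zero-mean signals, prediction projected onto the target). It is an illustration, not auraloss's implementation; the function names, the `EPS` constant, and the test signals are assumptions. Since the loss is -SNR, a negative SNR loss simply means the SNR is above 0 dB, i.e. the residual carries less energy than the target.

```python
import numpy as np

EPS = 1e-8  # assumed small constant to avoid division by zero and log of zero


def snr_db(pred: np.ndarray, target: np.ndarray) -> float:
    """SNR in dB, treating the residual (target - pred) as noise."""
    noise = target - pred
    return float(10.0 * np.log10(np.sum(target**2) / (np.sum(noise**2) + EPS) + EPS))


def si_snr_db(pred: np.ndarray, target: np.ndarray) -> float:
    """Scale-invariant SNR in dB: zero-mean signals, pred projected onto target."""
    pred = pred - pred.mean()
    target = target - target.mean()
    # Component of the prediction that lies along the target direction.
    alpha = np.dot(pred, target) / (np.dot(target, target) + EPS)
    s_target = alpha * target
    e_noise = pred - s_target
    return float(10.0 * np.log10(np.sum(s_target**2) / (np.sum(e_noise**2) + EPS) + EPS))


# As losses the metrics are negated: minimizing -SNR maximizes SNR.
def snr_loss(pred, target):
    return -snr_db(pred, target)


def si_snr_loss(pred, target):
    return -si_snr_db(pred, target)


# 1 s of a 440 Hz tone at 16 kHz; the estimate adds a small 1 kHz interference.
sr = 16000
t = np.arange(sr) / sr
target = np.sin(2 * np.pi * 440 * t)
pred = target + 0.1 * np.cos(2 * np.pi * 1000 * t)

print(snr_loss(pred, target))           # about -20.0: good estimate, negative loss
print(si_snr_loss(pred, target))        # about -20.0
print(snr_loss(0.5 * pred, target))     # about -5.98: SNR penalizes the gain change
print(si_snr_loss(0.5 * pred, target))  # about -20.0: SI-SNR ignores the gain change
```

The last two lines show the practical difference between the two: plain SNR treats a change in gain as error, while SI-SNR only measures how much of the estimate is not aligned with the target.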

Thanks for your interest in auraloss!

LearnedVector commented 3 years ago

@csteinmetz1 thanks for the explanation! In that case, how would we use the SNR loss in combination with other losses like L1 or MultiResolutionSTFTLoss? Thanks for making this great library!

csteinmetz1 commented 3 years ago

It should just be a matter of computing those losses on your predictions individually and then summing them. You might also want to apply some weighting to each term, since the losses can have quite different scales. For example, here is some pseudocode:

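import torch
import auraloss

# placeholder model, optimizer, and data for illustration only (not part of auraloss)
model = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
input = torch.randn(4, 1, 44100)   # (batch, channels, samples)
target = torch.randn(4, 1, 44100)

# create the losses
snr = auraloss.time.SNRLoss()
mrstft = auraloss.freq.MultiResolutionSTFTLoss()

# make your predictions
optimizer.zero_grad()
pred = model(input)

# compute each loss and sum (SNRLoss already returns -SNR, so no sign flip is needed)
snr_loss = snr(pred, target)
mrstft_loss = mrstft(pred, target)
loss = snr_loss + mrstft_loss

loss.backward()
optimizer.step()
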
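Following the suggestion to weight the terms, here is a minimal sketch of a weighted combination. The `weighted_sum` helper, the loss names, and the weight values are hypothetical and not part of auraloss; suitable weights depend on your task and need tuning. Because it only uses `*` and `+`, the same helper works on plain floats or on PyTorch scalar tensors such as `snr_loss` and `mrstft_loss`, and gradients flow through it unchanged.

```python
def weighted_sum(losses, weights):
    """Hypothetical helper: sum each loss term scaled by its weight (default 1.0)."""
    total = 0.0
    for name, value in losses.items():
        total = total + weights.get(name, 1.0) * value
    return total


# Illustrative values only: an SNR loss of -20 (i.e. 20 dB SNR) and an MR-STFT loss of 1.5.
losses = {"snr": -20.0, "mrstft": 1.5}
weights = {"snr": 0.1, "mrstft": 1.0}  # hypothetical weights
print(weighted_sum(losses, weights))  # -0.5
```

Note that the combined loss can legitimately be negative, for the same reason the SNR loss on its own can be.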
LearnedVector commented 3 years ago

@csteinmetz1 thank you, this was super helpful!