csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0
695 stars 66 forks

Adding perceptual weighting to `SumAndDifferenceSTFTLoss` #48

Closed: csteinmetz1 closed this 1 year ago

csteinmetz1 commented 1 year ago

This should enable the use of simple A-weighting as a pre-filtering process before computing the sum and difference signals.

Example usage:

import torch
import auraloss

# Stereo signals shaped (batch, channels, samples)
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,  # A-weighting pre-filter
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)
res = loss_fn(pred, target)

Notes:

csteinmetz1 commented 1 year ago

If you are using `scale="mel"` or `perceptual_weighting=True`, you will need to move the loss function to the same device as the model. This needs a note in the README with some examples.