csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

reduction=none broken #69

Open turian opened 6 months ago

turian commented 6 months ago
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss(reduction="none")

input = torch.rand(8,1,44100)
target = torch.rand(8,1,44100)

loss = mrstft(input, target)
print(loss)
print(loss.shape)

gives

Traceback (most recent call last):
  File "/private/tmp/testaura.py", line 9, in <module>
    loss = mrstft(input, target)
  File "/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.9/site-packages/auraloss/freq.py", line 410, in forward
    mrstft_loss += f(x, y)
RuntimeError: The size of tensor a (368) must match the size of tensor b (184) at non-singleton dimension 2

What I'm really looking for is a per-instance reduction, which I could compute myself from reduction=none. In any case, reduction=none does not work out of the box, which is unfortunately a showstopper for me :(

mcherep commented 4 months ago

I have the same issue

csteinmetz1 commented 4 months ago

The reason this currently doesn't work is that the desired behavior of reduction="none" for the multi-resolution loss is a bit ambiguous. Each STFTLoss produces an output with a different shape, in both the frequency axis and the time axis, so the per-resolution losses cannot be summed into a single loss tensor as they normally are.
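The mismatch can be reproduced with plain arithmetic. This is a minimal sketch that assumes MultiResolutionSTFTLoss defaults to fft_sizes [1024, 2048, 512] with hop_sizes [120, 240, 50], and that the STFT uses torch.stft's default center=True framing; `stft_shape` is a hypothetical helper. Under those assumptions the first two frame counts are the 368 and 184 from the traceback, and the same formula gives the shapes printed further down for 131072 samples.

```python
# Output shape (bins, frames) of one STFT resolution, assuming torch.stft
# with center=True (frames = n_samples // hop_size + 1) and a one-sided
# spectrum (bins = fft_size // 2 + 1). win_length does not affect the shape.
def stft_shape(n_samples, fft_size, hop_size):
    bins = fft_size // 2 + 1
    frames = n_samples // hop_size + 1
    return bins, frames


# Assumed MultiResolutionSTFTLoss defaults, as (fft_size, hop_size) pairs
resolutions = [(1024, 120), (2048, 240), (512, 50)]

for fft_size, hop_size in resolutions:
    print(stft_shape(44100, fft_size, hop_size))
# (513, 368)
# (1025, 184)
# (257, 883)
```

Because both the bin count and the frame count change from one resolution to the next, no broadcasting rule can add these tensors together.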

We could consider returning a list of tensors, one per STFTLoss output, but that would give a different return type from the normal behavior, which is a single tensor. Would this address your applications @turian, @mcherep?

mcherep commented 4 months ago

I think this would work for me. What I need is one aggregated loss per batch item instead of a loss aggregated over all dimensions. Let me know if I'm missing something!

csteinmetz1 commented 4 months ago

Perhaps the best way to achieve the desired behavior right now is to manage a set of STFTLoss instances yourself. That way you decide how their outputs are aggregated, in this case not at all. I am hesitant to add more complexity to the possible return types of the MultiResolutionSTFTLoss class if we can avoid it, since that class is really just a convenience wrapper around STFTLoss. Here is a potential example.

import torch
import auraloss

fft_sizes = [512, 1024, 2048]
win_lengths = [512, 1024, 2048]
hop_sizes = [256, 512, 1024]
reduction = "none"

loss_fns = torch.nn.ModuleList()

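for fft_size, win_length, hop_size in zip(fft_sizes, win_lengths, hop_sizes):
    # keyword arguments make the (fft_size, hop_size, win_length) order explicit
    loss_fns.append(
        auraloss.freq.STFTLoss(
            fft_size=fft_size, hop_size=hop_size, win_length=win_length, reduction=reduction
        )
    )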

bs = 4
chs = 1
seq_len = 131072

x = torch.randn(bs, chs, seq_len)
y = torch.randn(bs, chs, seq_len)

for loss_fn in loss_fns:
    loss = loss_fn(x, y)  # unreduced: [batch, bins, frames]
    print(loss.shape)

outputs

torch.Size([4, 257, 513])
torch.Size([4, 513, 257])
torch.Size([4, 1025, 129])
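For the one-loss-per-batch-item case, the unreduced tensors from that loop can be collapsed item by item. This is a minimal sketch: `per_instance_loss` is a hypothetical helper, the stand-in tensors just reuse the shapes printed above, and averaging (rather than summing) across resolutions is an arbitrary choice made here.

```python
import torch


def per_instance_loss(losses):
    """Collapse unreduced per-resolution losses into one value per batch item.

    Each tensor is assumed to be shaped [batch, bins, frames]; bins and frames
    differ between resolutions, but the batch dimension is shared.
    """
    # Average each resolution over its own frequency and time axes -> [batch]
    per_res = [loss.mean(dim=(-2, -1)) for loss in losses]
    # Stack to [n_resolutions, batch], then average across resolutions -> [batch]
    return torch.stack(per_res).mean(dim=0)


# Stand-ins with the shapes printed above (batch of 4, three resolutions)
losses = [torch.rand(4, 257, 513), torch.rand(4, 513, 257), torch.rand(4, 1025, 129)]
per_item = per_instance_loss(losses)
print(per_item.shape)  # torch.Size([4])
```

With the ModuleList from the example, the call would be `per_instance_loss([loss_fn(x, y) for loss_fn in loss_fns])`.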

mcherep commented 4 months ago

Great, I will do this instead. Thanks so much!

egaznep commented 2 months ago

Another hacky way to do this, while keeping all the convenience of the MultiResolutionSTFTLoss class, is to monkey-patch the auraloss.freq.apply_reduction function as follows:

auraloss.freq.apply_reduction = lambda losses, reduction: losses.mean(dim=(-1,-2))

The last two dimensions are the STFT bins and frames, so this averages over them and keeps every other dimension.
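One downside of assigning to auraloss.freq.apply_reduction directly is that the override is global for the rest of the process, so every loss in auraloss.freq that calls it is affected from then on. Below is a minimal sketch that scopes the same patch with unittest.mock.patch.object and restores the original on exit. `per_instance_mean` and `per_instance_reduction` are hypothetical names, and the sketch relies on the same assumption as the comment above: that auraloss.freq looks up apply_reduction as a module-level name at call time.

```python
import contextlib
from unittest import mock

import torch


def per_instance_mean(losses, reduction=None):
    # Same rule as the lambda above: average over the last two dims
    # (STFT bins and frames) and keep the rest. `reduction` is ignored,
    # so while the patch is active every reduction mode behaves this way.
    return losses.mean(dim=(-1, -2))


@contextlib.contextmanager
def per_instance_reduction(module):
    """Swap module.apply_reduction for per_instance_mean, undoing it on exit."""
    with mock.patch.object(module, "apply_reduction", per_instance_mean):
        yield


# Usage (untested here, assumes auraloss is installed):
#   import auraloss
#   mrstft = auraloss.freq.MultiResolutionSTFTLoss(reduction="none")
#   with per_instance_reduction(auraloss.freq):
#       loss = mrstft(input, target)
```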

egaznep commented 2 months ago

@csteinmetz1 Also, I noticed that for the spectral convergence loss we still get a reduced (scalar) result despite reduction="none", because it is implemented like this:

        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

instead of

        return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
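The proposed per-instance version can be sketched as a standalone function. This assumes magnitudes shaped [batch, bins, frames]. `spectral_convergence_per_instance` is a hypothetical name, and the sketch uses torch.linalg.matrix_norm, which computes the same Frobenius norm over the last two dimensions as the torch.norm(..., dim=(-1, -2)) call above (the PyTorch docs mark torch.norm as deprecated).

```python
import torch


def spectral_convergence_per_instance(x_mag, y_mag):
    """Spectral convergence computed separately for each batch item.

    x_mag, y_mag: magnitude spectrograms shaped [batch, bins, frames].
    Returns shape [batch, 1, 1] so it broadcasts against other unreduced
    loss terms shaped [batch, bins, frames].
    """
    # Frobenius norm over each item's (bins, frames) plane
    num = torch.linalg.matrix_norm(y_mag - x_mag, ord="fro", keepdim=True)
    den = torch.linalg.matrix_norm(y_mag, ord="fro", keepdim=True)
    return num / den


torch.manual_seed(0)
x_mag = torch.rand(3, 257, 64)
y_mag = torch.rand(3, 257, 64)
sc = spectral_convergence_per_instance(x_mag, y_mag)
print(sc.shape)  # torch.Size([3, 1, 1])
```

For a batch with a single item this matches the current global version. With more items, each item is normalized by its own target energy instead of the whole batch's.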

Noticing this also made me wonder whether the denominator of the current implementation makes sense. I know that a few other works, such as ParallelWaveGAN, implement this loss the same way, but I still find it rather counterintuitive that it is computed globally instead of "per instance".
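Here is a small worked example of why the global denominator can be counterintuitive. The norms below are made-up numbers for two items with the same absolute spectral error, one quiet and one loud. Because the squared Frobenius norm of a batch is the sum of the items' squared norms, the global ratio is sqrt(sum of squared errors) / sqrt(sum of squared target norms).

```python
import math

# Hypothetical per-item Frobenius norms: ||Y_i - X_i|| and ||Y_i||
diff_norms = [1.0, 1.0]     # same absolute error for both items
target_norms = [1.0, 10.0]  # quiet item, loud item

# Global SC (current implementation): one ratio over the whole batch.
# The squared Frobenius norm of the batch is the sum of per-item squares.
global_sc = math.sqrt(sum(d**2 for d in diff_norms)) / math.sqrt(
    sum(t**2 for t in target_norms)
)

# Per-instance SC, averaged over the batch afterwards
per_item = [d / t for d, t in zip(diff_norms, target_norms)]
mean_sc = sum(per_item) / len(per_item)

print(f"global: {global_sc:.3f}, per-instance: {per_item}, mean: {mean_sc:.3f}")
# global: 0.141, per-instance: [1.0, 0.1], mean: 0.550
```

The global value of about 0.14 sits close to the loud item's 0.1, so the quiet item's relative error of 1.0 barely affects it.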