turian opened this issue 11 months ago
I have the same issue
The reason this currently doesn't work is that the desired behavior for `reduction="none"` when using the multi-resolution loss is a bit ambiguous. Since each `STFTLoss` produces a different shape in both the frequency axis and the time axis, it is not possible to combine them into a single loss tensor, as is the normal behavior. We could consider returning a list of tensors, one per `STFTLoss` output, but this would give a different return type than the normal behavior, which is a single tensor. Would this address your applications @turian, @mcherep?
I think this would work for me; what I need is one aggregated loss per batch element instead of aggregating over all dimensions. Let me know if I'm missing something!
Perhaps the best way to achieve the desired behavior right now is to manage a set of `STFTLoss` instances yourself. That way you can define how (or whether) they are aggregated. I am hesitant to add more complexity to the possible return types of the `MultiResolutionSTFTLoss` class if we can avoid it; that class is really just a convenience wrapper around `STFTLoss`. Here is a potential example.
```python
import torch
import auraloss

# One STFTLoss per resolution, each with reduction="none"
fft_sizes = [512, 1024, 2048]
win_lengths = [512, 1024, 2048]
hop_sizes = [256, 512, 1024]
reduction = "none"

loss_fns = torch.nn.ModuleList()
for fft_size, win_length, hop_size in zip(fft_sizes, win_lengths, hop_sizes):
    loss_fns.append(
        auraloss.freq.STFTLoss(fft_size, hop_size, win_length, reduction=reduction)
    )

bs = 4
chs = 1
seq_len = 131072
x = torch.randn(bs, chs, seq_len)
y = torch.randn(bs, chs, seq_len)

# Each resolution yields an unreduced loss tensor of a different shape
for loss_fn in loss_fns:
    loss = loss_fn(x, y)
    print(loss.shape)
```
which outputs:

```
torch.Size([4, 257, 513])
torch.Size([4, 513, 257])
torch.Size([4, 1025, 129])
```
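If the goal is one aggregated loss per batch element, as described above, one option is to reduce each resolution's output over its own bin and frame dims and then average across resolutions. A minimal sketch building on the example above:

```python
# Reduce each resolution over its (bins, frames) dims, then average the
# per-resolution values, leaving one loss value per batch element.
per_item = torch.stack(
    [loss_fn(x, y).mean(dim=(-1, -2)) for loss_fn in loss_fns]
).mean(dim=0)
print(per_item.shape)  # torch.Size([4])
```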
Great, I will do this instead. Thanks so much!
I came up with another hacky way to do this while keeping all the convenience of the `MultiResolutionSTFTLoss` class: monkey-patch the `auraloss.freq.apply_reduction` function in the following way:
```python
auraloss.freq.apply_reduction = lambda losses, reduction: losses.mean(dim=(-1, -2))
```
(the last two dims are the STFT bins and frames; this reduction averages over them while keeping every other dimension)
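To make that concrete, here is a minimal sketch of how the patch might be used, assuming it is installed before the loss's forward pass; once the last two dims are averaged away, the per-resolution tensors share the same shape and can be combined by `MultiResolutionSTFTLoss` as usual. Note that the spectral convergence caveat raised below can still affect how fully per-instance the result is:

```python
import torch
import auraloss

# Patch the module-level reduction so every STFTLoss keeps the batch dim.
auraloss.freq.apply_reduction = lambda losses, reduction: losses.mean(dim=(-1, -2))

mrstft = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[256, 512, 1024],
    win_lengths=[512, 1024, 2048],
)

x = torch.randn(4, 1, 131072)
y = torch.randn(4, 1, 131072)

loss = mrstft(x, y)  # expected: one loss value per batch element, not a scalar
print(loss.shape)
```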
@csteinmetz1 Also, I noticed that for the spectral convergence loss, despite the option `reduction="none"`, we nevertheless get a reduced result, because it has been implemented like this:
```python
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
```

instead of

```python
return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
```
Noticing this also made me wonder whether the denominator of the current implementation makes sense. I know that a few other works, like ParallelWaveGAN, implement this loss the same way, but I still find it counterintuitive that it is computed globally instead of per instance.
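For illustration, a per-instance variant could look like the sketch below; `per_instance_spectral_convergence` is a hypothetical helper, not part of auraloss:

```python
import torch

def per_instance_spectral_convergence(x_mag, y_mag, eps=1e-8):
    # Frobenius norm over only the (bins, frames) dims, so batch and
    # channel dims are preserved; eps guards against a zero denominator.
    num = torch.norm(y_mag - x_mag, p="fro", dim=(-2, -1))
    den = torch.norm(y_mag, p="fro", dim=(-2, -1))
    return num / (den + eps)

x_mag = torch.rand(4, 513, 257)
y_mag = torch.rand(4, 513, 257)
print(per_instance_spectral_convergence(x_mag, y_mag).shape)  # torch.Size([4])
```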
What I'm really looking for is a per-instance reduction, which I could compute from `reduction="none"`. Anyway, `reduction="none"` is not working out of the box, which is unfortunately a showstopper for me :(