csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

fix reduction mistake in SpectralConvergenceLoss #75

Open renared opened 2 months ago

renared commented 2 months ago

I noticed that evaluating the STFT loss over my validation dataset gave different results depending on the batch size. I traced the cause to the spectral convergence term, then came across the comment by @egaznep in issue #69. It does not make sense to reduce the denominator over all dimensions, including the batch dimension, so I believe their suggestion should be used instead.

This snippet shows the difference:

import torch
from auraloss.freq import STFTLoss

# 1024 batches, each holding 4 random (input, target) signals of 16384 samples
batches = [(torch.randn(4, 1, 16384), torch.randn(4, 1, 16384)) for _ in range(1024)]
# the same data concatenated along the batch dimension into one batch of 4096 items
batchall = tuple(torch.concat(u, dim=0) for u in zip(*batches))

print("with spectral convergence enabled")
loss = STFTLoss()
# average of the per-batch losses vs. a single loss over all items at once
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
print("over full dataset:", loss(*batchall))

print("with spectral convergence disabled")
loss = STFTLoss(w_sc=0)  # w_sc=0 removes the spectral convergence term
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
print("over full dataset:", loss(*batchall))

Before:

with spectral convergence enabled
mean of losses: tensor(1.3511)
over full dataset: tensor(1.3493)
with spectral convergence disabled
mean of losses: tensor(0.6950)
over full dataset: tensor(0.6950)

After:

with spectral convergence enabled
mean of losses: tensor(1.3726)
over full dataset: tensor(1.3726)
with spectral convergence disabled
mean of losses: tensor(0.7095)
over full dataset: tensor(0.7095)
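
The batch-size dependence can also be shown in isolation on two tiny hand-made "spectrograms". This is only a minimal sketch and not the auraloss code: `sc_global` and `sc_per_item` are hypothetical names, and the magnitudes are assumed to have shape (batch, freq, time). The first function takes both norms over every dimension, batch included. The second takes the ratio per item and then averages, which is one way to make the result independent of how items are batched.

```python
import torch

def sc_global(x_mag, y_mag):
    # Hypothetical helper: both Frobenius norms span ALL dimensions (batch included).
    return torch.linalg.norm(y_mag - x_mag) / torch.linalg.norm(y_mag)

def sc_per_item(x_mag, y_mag):
    # Hypothetical helper: norms over (freq, time) only, one ratio per item, then the mean.
    num = torch.linalg.norm(y_mag - x_mag, dim=(-2, -1))
    den = torch.linalg.norm(y_mag, dim=(-2, -1))
    return (num / den).mean()

# Two items with shape (freq=2, time=2).
# Item 0: target norm 5, error norm 5  -> ratio 1.0
# Item 1: target norm 10, error norm 0 -> ratio 0.0
y_mag = torch.tensor([[[3.0, 4.0], [0.0, 0.0]],
                      [[6.0, 8.0], [0.0, 0.0]]])
x_mag = torch.tensor([[[3.0, 4.0], [3.0, 4.0]],
                      [[6.0, 8.0], [0.0, 0.0]]])

def mean_over_singles(fn):
    # Evaluate with batch size 1 and average the per-batch losses.
    losses = [fn(x_mag[i:i + 1], y_mag[i:i + 1]) for i in range(len(y_mag))]
    return torch.stack(losses).mean()

global_full = sc_global(x_mag, y_mag).item()
global_singles = mean_over_singles(sc_global).item()
per_item_full = sc_per_item(x_mag, y_mag).item()
per_item_singles = mean_over_singles(sc_per_item).item()

print(f"global norms,   full batch:   {global_full:.4f}")       # 0.4472
print(f"global norms,   batch size 1: {global_singles:.4f}")    # 0.5000
print(f"per-item ratio, full batch:   {per_item_full:.4f}")     # 0.5000
print(f"per-item ratio, batch size 1: {per_item_singles:.4f}")  # 0.5000
```

With global norms, the item with the larger target norm dominates the denominator, so the full-batch value (5 / sqrt(125)) differs from the mean of the single-item values. The per-item version returns 0.5 either way.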