csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

scale_invariance broken #56

Open · turian opened this issue 1 year ago

turian commented 1 year ago

I just set scale_invariance=True in MultiResolutionSTFTLoss. The tensors I pass in have these shapes:

torch.Size([2, 7307212]) torch.Size([2, 7307212])

Since enabling it, I get the following error with both 0.3.0 and main:

...
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/auraloss/freq.py", line 371, in forward
    mrstft_loss += f(x, y)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/auraloss/freq.py", line 202, in forward
    y_mag = y_mag * alpha.unsqueeze(-1)
RuntimeError: The size of tensor a (513) must match the size of tensor b (2) at non-singleton dimension 1
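A minimal sketch of the same shape mismatch, using plain random tensors instead of auraloss. The shapes are inferred from the error message (2 signals, 513 frequency bins). The frame count and the least-squares form of alpha are assumptions, since the issue does not show how alpha is computed.

```python
import torch

# Shapes inferred from the error: 2 signals, 513 frequency bins.
# The number of frames (100) is an arbitrary, hypothetical choice.
batch, n_bins, n_frames = 2, 513, 100
x_mag = torch.rand(batch, n_bins, n_frames)
y_mag = torch.rand(batch, n_bins, n_frames)

# Assumed per-signal gain: one scalar per example, shape (batch,).
alpha = (x_mag * y_mag).sum([-2, -1]) / (y_mag ** 2).sum([-2, -1])

# alpha.unsqueeze(-1) has shape (2, 1). Broadcasting right-aligns shapes,
# so (2, 1) is compared against y_mag's trailing dims (513, 100):
# 1 vs 100 is fine, but 2 vs 513 is not.
try:
    y_mag * alpha.unsqueeze(-1)
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print(err)  # reports a size mismatch (513 vs 2) at dimension 1
```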

csteinmetz1 commented 1 year ago

scale_invariance is currently still a bit experimental, but it shouldn't throw an error.

Let me try to reproduce.

turian commented 1 year ago

As a quick hack, @khumairraj suggests this fix:

            # alpha: (batch,) -> (batch, 1, 1), so it broadcasts over y_mag's (freq_bins, frames) axes
            y_mag = y_mag * alpha.unsqueeze(-1).unsqueeze(-1)
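One way to check the suggested fix in isolation is the sketch below. scale_invariant_target is a hypothetical helper, not part of auraloss. It assumes magnitudes shaped (..., freq_bins, frames) and a least-squares gain alpha = sum(x_mag * y_mag) / sum(y_mag ** 2) taken over the last two dimensions; the issue itself does not show how alpha is computed.

```python
import torch


def scale_invariant_target(x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
    """Rescale y_mag by a per-example gain so it best matches x_mag.

    Hypothetical helper: assumes inputs shaped (..., freq_bins, frames)
    and an assumed least-squares gain over the last two dimensions.
    """
    # One gain per leading index, e.g. shape (batch,).
    alpha = (x_mag * y_mag).sum([-2, -1]) / (y_mag ** 2).sum([-2, -1])
    # Two trailing singleton dims give shape (..., 1, 1), which broadcasts
    # over both the frequency and frame axes, as in the suggested fix.
    return y_mag * alpha.unsqueeze(-1).unsqueeze(-1)


x_mag = torch.rand(2, 513, 100)
y_mag = 2.0 * x_mag  # target is the estimate scaled by 2
y_scaled = scale_invariant_target(x_mag, y_mag)
print(y_scaled.shape)  # torch.Size([2, 513, 100])
print(torch.allclose(y_scaled, x_mag))  # True: the gain of 0.5 undoes the factor of 2
```

Because the gain is reduced over only the last two axes, the same helper also handles extra leading dimensions such as (batch, channels, freq_bins, frames). Indexing with alpha[..., None, None] is an equivalent way to add the two singleton dimensions. A more defensive version might add a small epsilon to the denominator so an all-zero target does not divide by zero.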