csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0
745 stars 67 forks

`FIRFilter` Class Failed when running on CPU (but not on CUDA) #62

Closed int0thewind closed 1 year ago

int0thewind commented 1 year ago

Hi!

As recommended by Välimäki et al., a pre-emphasis filter can be applied before computing the ESR loss. An auraloss.perceptual.FIRFilter instance, however, cannot be called successfully when the PyTorch device is the CPU. Interestingly, the same instance runs on an NVIDIA CUDA device without any runtime error.
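
For context, the wrapper visible in the traceback below composes the two modules roughly as follows. This is a minimal sketch assuming the default FIRFilter settings and the auraloss.time.ESRLoss class, not the exact code from the project:

import torch
import auraloss

class PreEmphasisESRLoss(torch.nn.Module):
    """ESR loss computed after a pre-emphasis FIR filter (sketch)."""

    def __init__(self):
        super().__init__()
        # FIRFilter defaults to a first-order highpass pre-emphasis filter.
        self.pre_emphasis_filter = auraloss.perceptual.FIRFilter()
        self.esr = auraloss.time.ESRLoss()

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Filter prediction and target with the same FIR, then compute ESR.
        y_hat, y = self.pre_emphasis_filter(y_hat, y)
        return self.esr(y_hat, y)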

Expected Behavior

An auraloss.perceptual.FIRFilter instance can be called successfully regardless of the device.

Current Behavior

When calling an auraloss.perceptual.FIRFilter instance on the CPU, a runtime error is raised.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 37
     33             losses[k].append(v)
     35     pd.DataFrame(losses).to_csv(job_eval_dir / f'loss.csv')
---> 37 test()

File ~/.local/share/virtualenvs/s4-dynamic-range-compressor-WjUGfTKg/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

Cell In[4], line 26, in test()
     23     y_hat: Tensor = model(x, parameters)
     25     for validation_loss, validation_criterion in validation_criterions.items():
---> 26         loss: Tensor = validation_criterion(y_hat.unsqueeze(1), y.unsqueeze(1))
     27         validation_losses[validation_loss] += loss.item()
     29 for k, v in list(validation_losses.items()):

File ~/.local/share/virtualenvs/s4-dynamic-range-compressor-WjUGfTKg/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Developer/s4-dynamic-range-compressor/src/loss.py:42, in PreEmphasisESRLoss.forward(self, y_hat, y)
     40 def forward(self, y_hat: Tensor, y: Tensor) -> Tensor:
     41     if self.pre_emphasis_filter:
---> 42         y_hat, y = self.pre_emphasis_filter(y_hat, y)
     43     return self.esr(y_hat, y)

File ~/.local/share/virtualenvs/s4-dynamic-range-compressor-WjUGfTKg/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/share/virtualenvs/s4-dynamic-range-compressor-WjUGfTKg/lib/python3.10/site-packages/auraloss/perceptual.py:125, in FIRFilter.forward(self, input, target)
    117 def forward(self, input, target):
    118     """Calculate forward propagation.
    119     Args:
    120         input (Tensor): Predicted signal (B, #channels, #samples).
   (...)
    123         Tensor: Filtered signal.
    124     """
--> 125     input = torch.nn.functional.conv1d(
    126         input, self.fir.weight.data, padding=self.ntaps // 2
    127     )
    128     target = torch.nn.functional.conv1d(
    129         target, self.fir.weight.data, padding=self.ntaps // 2
    130     )
    131     return input, target

RuntimeError: NNPACK SpatialConvolution_updateOutput failed

Steps to Reproduce

  1. Create an auraloss.perceptual.FIRFilter instance.
  2. Create two three-dimensional (batch size, audio channels, sample length) float32 PyTorch tensors with the same shape.
  3. Move the tensors and the FIRFilter instance to the CPU device.
  4. Call the FIRFilter instance with the two tensors as arguments (a minimal sketch follows this list).
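
A minimal sketch of these steps (illustrative, not from the original report; the batch size of 16 reflects the threshold identified in the discussion below):

import torch
import auraloss

fir = auraloss.perceptual.FIRFilter()                # step 1
x = torch.randn(16, 1, 44100, dtype=torch.float32)   # step 2: (batch, channels, samples)
y = torch.randn(16, 1, 44100, dtype=torch.float32)
fir, x, y = fir.to('cpu'), x.to('cpu'), y.to('cpu')  # step 3
x_f, y_f = fir(x, y)                                 # step 4: raises the NNPACK RuntimeError on CPU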

Context (Environment)

CPU: Apple M1 Max

sys.version = '3.10.12 (main, Jun 20 2023, 19:43:52) [Clang 14.0.3 (clang-1403.0.22.14.1)]'
platform.platform() = 'macOS-13.4.1-arm64-arm-64bit'
device = device(type='cpu')
torch.__version__ = '2.0.0'
csteinmetz1 commented 1 year ago

Hi, thanks for raising this issue. I did some testing and found the following: the example below runs without error on my M1 Mac, but only when the batch size is less than 16. When I set bs = 16, I get the same error you reported. This does not appear to be a problem with auraloss, but rather with the torch CPU backend, specifically the convolution operation in NNPACK. For now, if you are using auraloss for evaluation on the CPU, I would suggest using a smaller batch size as a workaround. Let me know if that works.

import torch
import auraloss

bs = 2  # batch sizes below 16 run fine; bs >= 16 reproduces the NNPACK error
chs = 1
seq_len = 44100

x = torch.randn(bs, chs, seq_len)
y = torch.randn(bs, chs, seq_len)

fir = auraloss.perceptual.FIRFilter()

x_out, y_out = fir(x, y)
print(x_out.shape, y_out.shape)
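
On an affected machine, a quick sweep like the following can confirm the threshold. This is an illustrative check, and the exact failing batch size may depend on the torch build:

import torch
import auraloss

fir = auraloss.perceptual.FIRFilter()
for bs in (2, 4, 8, 15, 16, 32):
    x = torch.randn(bs, 1, 44100)
    y = torch.randn(bs, 1, 44100)
    try:
        fir(x, y)
        print(bs, "ok")
    except RuntimeError as err:
        print(bs, "failed:", err)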
int0thewind commented 1 year ago

Thanks, Christian! Yes, the error occurs when the batch size is greater than or equal to 16. That is an interesting PyTorch bug; maybe I should report it to the PyTorch team in the future.
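
One way to apply the suggested workaround without shrinking the evaluation batch itself is to run the filter over sub-batches. A hedged sketch, where the helper name and chunk size are illustrative and not part of auraloss:

import torch
import auraloss

def fir_in_chunks(fir, x, y, max_bs=8):
    """Apply FIRFilter to (B, C, T) tensors in sub-batches of at most max_bs."""
    xs, ys = [], []
    for x_c, y_c in zip(x.split(max_bs), y.split(max_bs)):
        x_f, y_f = fir(x_c, y_c)
        xs.append(x_f)
        ys.append(y_f)
    return torch.cat(xs), torch.cat(ys)

fir = auraloss.perceptual.FIRFilter()
x = torch.randn(32, 1, 44100)
y = torch.randn(32, 1, 44100)
x_out, y_out = fir_in_chunks(fir, x, y)  # each chunk stays below the failing batch size
print(x_out.shape, y_out.shape)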

Fix: the input tensors should be three-dimensional, not two-dimensional. I have corrected that in my initial post.