csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0
695 stars 66 forks

FIRFilter seems to not support stereo audio #42

Closed IvanDSM closed 1 year ago

IvanDSM commented 1 year ago

I attempted to use FIRFilter on batched stereo input and got this error:

/usr/local/lib/python3.8/dist-packages/auraloss/perceptual.py in forward(self, input, target)
    122             Tensor: Filtered signal.
    123         """
--> 124         input = torch.nn.functional.conv1d(
    125             input, self.fir.weight.data, padding=self.ntaps // 2
    126         )

RuntimeError: Given groups=1, weight of size [1, 1, 101], expected input[96, 2, 8192] to have 1 channels, but got 2 channels instead

Would it be possible to implement FIRFilter in a way that can handle stereo or even multichannel audio?

Thanks a lot for the library by the way, it's great!

csteinmetz1 commented 1 year ago

Hi @IvanDSM,

One option would be to move the channels to the batch dimension before computing the output. Here is a simple example demonstrating this with some dummy data.

import torch
import auraloss

batch_size = 8
chs = 2
seq_len = 131072
sample_rate = 44100
y_hat = torch.randn(batch_size, chs, seq_len)
y = torch.randn(batch_size, chs, seq_len)
print(y_hat.shape, y.shape)

firfilter = auraloss.perceptual.FIRFilter(fs=sample_rate)

# reshape both by moving channels to the batch dimension
y_hat = y_hat.view(batch_size * chs, 1, -1)
y = y.view(batch_size * chs, 1, -1)
print(y_hat.shape, y.shape)

# apply the filter
y_hat_out, y_out = firfilter(y_hat, y)
print(y_hat_out.shape, y_out.shape)

# move the channels back
y_hat_out = y_hat_out.view(batch_size, chs, -1)
y_out = y_out.view(batch_size, chs, -1)
print(y_hat_out.shape, y_out.shape)
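If this comes up often, the channels-to-batch trick above could also be wrapped in a small module. Here is a hypothetical sketch in plain PyTorch (the `MultiChannelFIR` class and its moving-average taps are illustrative only, not part of the auraloss API):

```python
import torch

class MultiChannelFIR(torch.nn.Module):
    """Hypothetical wrapper: applies a single-channel FIR filter to
    multichannel audio by folding the channel dimension into the batch
    dimension before conv1d, then restoring it afterwards."""

    def __init__(self, taps: torch.Tensor):
        super().__init__()
        # taps: (ntaps,) impulse response, stored as a (1, 1, ntaps) conv1d kernel
        self.register_buffer("weight", taps.view(1, 1, -1))
        self.ntaps = taps.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, chs, seq_len = x.shape
        x = x.view(bs * chs, 1, seq_len)  # move channels to the batch dim
        x = torch.nn.functional.conv1d(x, self.weight, padding=self.ntaps // 2)
        return x.view(bs, chs, -1)  # move channels back

# simple 101-tap moving-average filter as dummy taps
fir = MultiChannelFIR(torch.ones(101) / 101)
out = fir(torch.randn(8, 2, 8192))
print(out.shape)  # torch.Size([8, 2, 8192])
```

With an odd number of taps and `padding=ntaps // 2`, the output keeps the input length, so the reshape back to `(batch, channels, -1)` is lossless.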

Will this work for your use case?

csteinmetz1 commented 1 year ago

Closing for now.