Closed: IvanDSM closed this issue 1 year ago.
Hi @IvanDSM,
One option would be to move the channels to the batch dimension before computing the output. Here is a simple example demonstrating this with some dummy data.
```python
import torch
import auraloss

batch_size = 8
chs = 2
seq_len = 131072
sample_rate = 44100

y_hat = torch.randn(batch_size, chs, seq_len)
y = torch.randn(batch_size, chs, seq_len)
print(y_hat.shape, y.shape)

firfilter = auraloss.perceptual.FIRFilter(fs=sample_rate)

# reshape both by moving channels to the batch dimension
y_hat = y_hat.view(batch_size * chs, 1, -1)
y = y.view(batch_size * chs, 1, -1)
print(y_hat.shape, y.shape)

# apply the filter
y_hat_out, y_out = firfilter(y_hat, y)
print(y_hat_out.shape, y_out.shape)

# move the channels back
y_hat_out = y_hat_out.view(batch_size, chs, -1)
y_out = y_out.view(batch_size, chs, -1)
print(y_hat_out.shape, y_out.shape)
```
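If this comes up often, the same reshape round trip can be wrapped in a small helper so any mono-only filter works on multichannel batches. The sketch below is an illustration, not part of auraloss: `apply_mono_filter` is a hypothetical name, and a first-order pre-emphasis built from `conv1d` stands in for `FIRFilter` so the example runs with only PyTorch installed.

```python
import torch
import torch.nn.functional as F

def apply_mono_filter(filt, y_hat, y):
    """Apply a single-channel filter to (batch, channels, time) tensors
    by folding the channel dimension into the batch dimension."""
    batch_size, chs, _ = y_hat.shape
    y_hat = y_hat.reshape(batch_size * chs, 1, -1)
    y = y.reshape(batch_size * chs, 1, -1)
    y_hat, y = filt(y_hat, y)
    # restore (batch, channels, time); the time length may be shorter
    # than the input if the filter does not pad
    return y_hat.reshape(batch_size, chs, -1), y.reshape(batch_size, chs, -1)

# stand-in for FIRFilter: first-order pre-emphasis H(z) = 1 - 0.85 z^-1
# (the kernel is time-reversed because conv1d computes cross-correlation)
def preemphasis(y_hat, y):
    kernel = torch.tensor([[[-0.85, 1.0]]])
    return F.conv1d(y_hat, kernel), F.conv1d(y, kernel)

y_hat = torch.randn(8, 2, 1024)
y = torch.randn(8, 2, 1024)
out_hat, out = apply_mono_filter(preemphasis, y_hat, y)
print(out_hat.shape, out.shape)  # torch.Size([8, 2, 1023]) torch.Size([8, 2, 1023])
```

Passing `firfilter` in place of `preemphasis` should work the same way, since `FIRFilter` takes and returns an `(input, target)` pair.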
Will this work for your use case?
Closing for now.
I attempted to use FIRFilter on batched stereo input and got this error:
Would it be possible to implement FIRFilter in a way that can handle stereo or even multichannel audio?
Thanks a lot for the library by the way, it's great!