csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0
695 stars 66 forks

Fir filter input #65

Open jmoso13 opened 9 months ago

jmoso13 commented 9 months ago

In this PR I add an optional FIRFilter input to STFTLoss. When a filter is passed, it is assigned to self.prefilter; when none is provided, self.prefilter defaults to None. If only the perceptual_weighting flag is set, self.prefilter is instead set to an internally constructed FIRFilter.
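The selection rules above can be sketched as a small decision function. This is a minimal, framework-free sketch, not the PR's code: `resolve_prefilter` and the `make_perceptual_filter` factory are hypothetical names, filters are modeled as plain callables that take and return an `(input, target)` pair, and the chaining order used when both options are enabled (perceptual weighting first) is an assumption.

```python
from typing import Callable, Optional, Tuple

# Hypothetical filter type: takes (input, target), returns filtered (input, target).
Prefilter = Callable[[list, list], Tuple[list, list]]


def resolve_prefilter(
    fir_filter: Optional[Prefilter],
    perceptual_weighting: bool,
    make_perceptual_filter: Callable[[], Prefilter],
) -> Optional[Prefilter]:
    """Hypothetical helper: decide what a loss would store in self.prefilter."""
    if fir_filter is None and not perceptual_weighting:
        return None  # no filtering at all
    if not perceptual_weighting:
        return fir_filter  # external filter only
    perceptual = make_perceptual_filter()
    if fir_filter is None:
        return perceptual  # internally constructed perceptual filter only

    # Both requested: run the perceptual filter, then the external one
    # (this ordering is an assumption, not taken from the PR).
    def chained(input, target):
        input, target = perceptual(input, target)
        return fir_filter(input, target)

    return chained
```

The factory argument means the perceptual filter is only built when the flag actually requires it.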

If an external FIRFilter is provided and the perceptual_weighting flag is also set, a variant of nn.Sequential that accepts two inputs is constructed to apply both filters in sequence.
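Plain `nn.Sequential.forward` passes a single argument from module to module, which is why a two-input variant is needed. A minimal sketch of such a container, assuming each filter's `forward` takes `(input, target)` and returns both filtered tensors; `TwoInputSequential` and the toy `Gain` module are hypothetical illustrations, not the `FIRSequential` class from the PR.

```python
import torch
import torch.nn as nn


class TwoInputSequential(nn.Sequential):
    """Hypothetical two-input container (not the PR's FIRSequential).

    nn.Sequential.forward threads a single value through its children;
    this override threads an (input, target) pair instead.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        for module in self:  # iterating nn.Sequential yields child modules
            input, target = module(input, target)
        return input, target


class Gain(nn.Module):
    """Toy two-input 'filter' for demonstration: scales both signals."""

    def __init__(self, gain: float):
        super().__init__()
        self.gain = gain

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        return input * self.gain, target * self.gain


chain = TwoInputSequential(Gain(2.0), Gain(3.0))
x = torch.ones(1, 1, 4)          # (batch, channels, samples)
y = torch.full((1, 1, 4), 2.0)
x_f, y_f = chain(x, y)           # each signal is scaled by 2.0 then 3.0
```

Because the container is still an `nn.Sequential`, its children are registered as submodules, so their parameters and buffers follow `.to(device)` calls on the parent loss.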

I tested this on some audio input and it appears to work as expected. Below are spectrograms for each configuration.

No filter:

[spectrogram: audio_no_filter]

Only perceptual weighting:

[spectrogram: audio_only_percep_weight]

Only external 4.5 kHz lowpass filter:

[spectrogram: audio_only_4k_lowpass]

Both filters:

[spectrogram: audio_percep_weight_and_4k_lowpass]

This PR also includes changes to auraloss.perceptual.FIRFilter that allow Butterworth filter construction, and a FIRSequential class in auraloss.utils that inherits from nn.Sequential and accepts multiple inputs.
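One common way to obtain FIR taps from a Butterworth design is to compute the IIR filter's coefficients and truncate its impulse response. The sketch below does this for a 2nd-order lowpass in pure Python; the function name, the filter order, the 101-tap default, and the truncation approach are all assumptions for illustration and may differ from how the PR constructs its Butterworth filter.

```python
import math


def butterworth2_lowpass_fir(cutoff_hz: float, fs: float, ntaps: int = 101) -> list:
    """Hypothetical helper: approximate a 2nd-order Butterworth lowpass as FIR taps.

    Biquad coefficients come from the bilinear transform with frequency
    prewarping; the FIR taps are the first `ntaps` samples of the biquad's
    impulse response.
    """
    k = math.tan(math.pi * cutoff_hz / fs)  # prewarped analog cutoff
    norm = 1.0 / (1.0 + math.sqrt(2.0) * k + k * k)
    b0 = k * k * norm
    b1 = 2.0 * b0
    b2 = b0
    a1 = 2.0 * (k * k - 1.0) * norm
    a2 = (1.0 - math.sqrt(2.0) * k + k * k) * norm

    taps = []
    x1 = x2 = y1 = y2 = 0.0
    for n in range(ntaps):
        x0 = 1.0 if n == 0 else 0.0  # unit impulse
        # Direct-form I difference equation of the biquad.
        y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        taps.append(y0)
        x2, x1 = x1, x0
        y2, y1 = y1, y0
    return taps


taps = butterworth2_lowpass_fir(4500.0, 44100.0)  # 4.5 kHz cutoff, assumed 44.1 kHz rate
```

For this cutoff the taps sum to roughly 1 (unity gain at DC) and their alternating sum is close to 0 (a null at Nyquist). Truncation is cheap but yields a non-linear-phase FIR, and lower cutoffs decay more slowly, so they may need more taps.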

I haven't used this branch to train a model yet, but in my testing within this repo it returns losses as expected.