csteinmetz1 / IIRNet

Direct design of biquad filter cascades with deep learning by sampling random polynomials.
https://csteinmetz1.github.io/IIRNet/
Apache License 2.0

High error with long magnitude responses #15

Open SuperKogito opened 5 months ago

SuperKogito commented 5 months ago

Hello Christian,

Thank you for sharing this project. I am testing IIRNet on some magnitude responses of my own and comparing the results with an optimisation-based approach. Unfortunately, the IIRNet results are way off. Is there a way to fine-tune the output without retraining, or am I doing something wrong? My code looks like this:

import time 
import torch
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt
from scipy.signal import resample
from iirnet.designer import Designer

# ... prepare magnitude response 

# first load IIRNet with pre-trained weights
designer = Designer()

# Find IIR
f_in_hz = np.fft.rfftfreq(n=Nfft,d=1/ms.fs)
for nls in range(numLS):
    t0 = time.time()
    fmin_in_hz = 24
    fmax_in_hz = 16000

    target_eq_dB = lh.mag2db(np.abs(f_EQs[nls,lh.freq2index(fmin_in_hz,Nfft):lh.freq2index(fmax_in_hz,Nfft)]))
    # restrict the frequency grid to the band of interest (keeps the full grid intact)
    f_band = f_in_hz[(f_in_hz > fmin_in_hz) & (f_in_hz < fmax_in_hz)]

    EQopt = EQOptimizer()
    max_num_filters = 16
    eq_IIR, _, iir_fc, iir_Q, iir_G, iir_coeffs_a, iir_coeffs_b, f_axis = EQopt.optimize_biquad_filters(
        frequency=f_band, target=target_eq_dB, max_num_filters=max_num_filters, fs=float(ms.fs), logf=False)
    print("Number of biquads: " + str(iir_fc.shape[0]))
    print("OPTIMIZER Runtime: " + str((time.time() - t0) * 1000) + " ms")

    # IIRNet filter orders to try (an order-n design has n // 2 biquads)
    orders = [4, 8, 16, 32, 64]
    fig, axs = plt.subplots(len(orders), 1, figsize=(10, 5 * len(orders)))

    # Use IIRNet to compute an IIR filter for each order
    for i, n in enumerate(orders):
        m = target_eq_dB  # magnitude response specification in dB
        mode = "linear"   # interpolation mode for the specification
        output = "sos"    # output type ("sos" or "ba")

        # call the designer and time only the inference step
        t0 = time.time()
        sos = designer(n, m, mode=mode, output=output)
        print(f"IIRNET (order {n}) Runtime: {(time.time() - t0) * 1000:.1f} ms")

        # measure the response of the designed cascade
        w, h = scipy.signal.sosfreqz(sos.numpy(), fs=float(ms.fs))

        # upsample the response to the length of the target for plotting
        upsampled_array = resample(h, target_eq_dB.shape[0])

        axs[i].semilogx(f_axis[0, :], target_eq_dB)
        axs[i].semilogx(f_axis[0, :], eq_IIR)
        axs[i].semilogx(f_axis[0, :], 20 * np.log10(np.abs(upsampled_array)))
        axs[i].legend(("Target", "IIR-optimizer", "IIR-IIRNet"))
        axs[i].set_xlabel('$f$ in Hz')
        axs[i].set_ylabel(r'$|H(f)|$ in dB')
        axs[i].grid(True)
        axs[i].set_xlim([fmin_in_hz, fmax_in_hz])
        axs[i].set_ylim([-5., 5.])
        axs[i].set_title(f"Idx: {nls} | optimizer biquads: {iir_fc.shape[0]} | IIRNet order: {n} ({n // 2} biquads)")

    print("+" * 30)

    plt.tight_layout()
    plt.savefig(f"{nls}.png")
    plt.close(fig)  # free the figure before the next loudspeaker

Unfortunately, the inference results are not at all close to the expected response. They seem to capture some features of the curve, but the result is not usable :/ Any ideas whether I am doing something wrong here, or how I could get better results?

SuperKogito commented 5 months ago

Here is one of my plots. Unfortunately, the optimizer's result is more accurate, even though its order is lower.
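Whether the optimizer's order is really lower depends on the units being compared: the code prints `iir_fc.shape[0]` as a number of biquads, while IIRNet's `n` is the overall filter order (the IIRNet README describes it as "Desired filter order"). A second-order-sections array has one row of six coefficients per biquad, so an order-`n` design has `n // 2` rows. The sketch below illustrates this with SciPy Butterworth designs standing in for IIRNet outputs (an assumption; it does not call IIRNet), and `num_biquads` is a hypothetical helper.

```python
import numpy as np
import scipy.signal


def num_biquads(sos) -> int:
    """Return the number of second-order sections (biquads) in an SOS array."""
    sos = np.atleast_2d(np.asarray(sos))
    if sos.shape[-1] != 6:
        raise ValueError("SOS arrays must have 6 coefficients per section")
    return sos.shape[0]


# Butterworth designs stand in for IIRNet outputs here (assumption).
for order in [4, 8, 16, 32, 64]:
    sos = scipy.signal.butter(order, 0.25, output="sos")
    print(f"order {order}: {num_biquads(sos)} biquads")
```

With at most 16 biquads, the optimizer's design is at most order 32, so it is directly comparable to IIRNet with `n = 32`.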

I also tried interpolating with SciPy instead of upsampling, but this did not change the results.

from scipy.interpolate import interp1d

# Target x-coordinate values for the interpolation
x = np.arange(len(h))
target_x = np.linspace(0, len(h) - 1, target_eq_dB.shape[0])

# Linear interpolation onto the length of the target
interpolator = interp1d(x, h, kind='linear')
interpolated_array = interpolator(target_x)

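Both the `resample` call and the index-based `interp1d` map the 512 points returned by `sosfreqz` (which run from DC up to, but not including, Nyquist) onto the target's index positions, and the result is then plotted against `f_axis`, which only covers the 24 Hz to 16 kHz band. A sketch of an alternative, assuming the target frequencies are available in Hz: pass them directly as `worN` to `scipy.signal.sosfreqz` together with `fs`, so the response is evaluated at exactly those frequencies and no resampling is needed. The helper name `response_db_at` is hypothetical, and the Butterworth filter only stands in for an IIRNet design.

```python
import numpy as np
import scipy.signal


def response_db_at(sos, freqs_hz, fs):
    """Magnitude response of an SOS cascade in dB at the given frequencies (Hz)."""
    freqs_hz = np.asarray(freqs_hz, dtype=float)
    # With fs given, an array worN is interpreted in the same units as fs (Hz)
    _, h = scipy.signal.sosfreqz(np.asarray(sos), worN=freqs_hz, fs=fs)
    # A small floor avoids log10(0) if the response has exact zeros
    return 20 * np.log10(np.maximum(np.abs(h), 1e-12))


if __name__ == "__main__":
    fs = 48000.0
    freqs = np.geomspace(24, 16000, 200)  # example grid from 24 Hz to 16 kHz
    sos = scipy.signal.butter(4, 1000, fs=fs, output="sos")  # stand-in design
    mag_db = response_db_at(sos, freqs, fs)
    print(mag_db[:3], mag_db[-3:])
```

In the loop above this would replace the upsampling step with something like `axs[i].semilogx(f_axis[0, :], response_db_at(sos.numpy(), f_axis[0, :], float(ms.fs)))`.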
csteinmetz1 commented 5 months ago

Thanks for raising this. I can't reproduce your example because some of the variables aren't defined. My best guess is that the results are poor because the input specification isn't handled properly: it should be a magnitude response in dB, specified at normalized frequency points. If you can provide a more minimal example that reproduces the issue with the IIRNet call, I can try to debug it on my end.
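Following that advice, a band-limited target measured on an arbitrary frequency grid (here 24 Hz to 16 kHz) can be resampled onto points evenly spaced in normalized frequency before it is passed to the designer. This is a sketch under two assumptions not confirmed in this thread: that the designer treats the entries of `m` as evenly spaced from DC (0) to Nyquist (1), and that holding the edge values outside the measured band is acceptable. `to_normalized_spec` and `num_points` are hypothetical names.

```python
import numpy as np


def to_normalized_spec(freqs_hz, mag_db, fs, num_points=512):
    """Resample a dB magnitude target onto an evenly spaced normalized-frequency grid.

    Returns (w_norm, spec_db), where w_norm runs from 0 (DC) to 1 (Nyquist).
    Outside the measured band, np.interp holds the first/last target values.
    """
    freqs_hz = np.asarray(freqs_hz, dtype=float)
    mag_db = np.asarray(mag_db, dtype=float)
    if freqs_hz.shape != mag_db.shape:
        raise ValueError("freqs_hz and mag_db must have the same shape")
    order = np.argsort(freqs_hz)  # np.interp needs increasing x values
    nyquist = fs / 2.0
    w_norm = np.linspace(0.0, 1.0, num_points)
    spec_db = np.interp(w_norm * nyquist, freqs_hz[order], mag_db[order])
    return w_norm, spec_db


# Hypothetical usage with the variables from the code above:
# _, m = to_normalized_spec(f_axis[0, :], target_eq_dB, float(ms.fs))
# sos = designer(n, m.tolist(), mode="linear", output="sos")
```

Holding the edge values is only one choice; tapering the specification towards 0 dB outside the band would be another.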

SuperKogito commented 5 months ago

I tried normalising the frequency range, but the results are the same. I made a minimal example that reproduces this behaviour: iirnet_minimal_example. Thank you for taking the time to help me with this :)

csteinmetz1 commented 5 months ago

Thanks for putting that together. I was able to verify that the solutions you are getting from IIRNet are in fact what the model is producing. It seems that the particular target you are trying to fit is out of distribution for the model, which is a known limitation of IIRNet. I am curious where your target curve comes from; knowing that may help us understand why the model does not generalize well here. The fits from the IIR optimizer look quite good to me. Is there a particular reason you want to use IIRNet instead of the optimizer?
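To compare the IIRNet and optimizer fits with a number rather than by eye, one option is the RMS error in dB over the band of interest, with both responses evaluated on the target's own frequency grid. This metric and the helper `rms_db_error` are assumptions for illustration; they are not part of IIRNet or of the optimizer used above.

```python
import numpy as np


def rms_db_error(target_db, estimate_db, freqs_hz=None, fmin=None, fmax=None):
    """Root-mean-square difference in dB, optionally restricted to [fmin, fmax] Hz."""
    target_db = np.asarray(target_db, dtype=float)
    estimate_db = np.asarray(estimate_db, dtype=float)
    if target_db.shape != estimate_db.shape:
        raise ValueError("target and estimate must be on the same frequency grid")
    mask = np.ones(target_db.shape, dtype=bool)
    if freqs_hz is not None:
        freqs_hz = np.asarray(freqs_hz, dtype=float)
        if fmin is not None:
            mask &= freqs_hz >= fmin
        if fmax is not None:
            mask &= freqs_hz <= fmax
    if not mask.any():
        raise ValueError("no frequency points inside the requested band")
    diff = target_db[mask] - estimate_db[mask]
    return float(np.sqrt(np.mean(diff ** 2)))


# Hypothetical usage, assuming eq_IIR is in dB (as the plots suggest):
# err_opt = rms_db_error(target_eq_dB, eq_IIR)
# err_net = rms_db_error(target_eq_dB, iirnet_db)  # iirnet_db on the same grid
```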

SuperKogito commented 5 months ago

The target curve is an equalisation curve based on some speaker measurements.
Even though the optimiser seems to perform better in this case, it still suffers from the following drawbacks: