Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License
3.49k stars 306 forks source link

Constant loss, network is not learning #48

Open SuperKogito opened 1 week ago

SuperKogito commented 1 week ago

I am trying to predict certain function coefficients (output: a, b) based on its curve (input: frequency_response) with the help of Kolmogorov-Arnold Network and your nice library.

enter image description here

Unfortunately, my loss is constant and is not improving at all. Any idea what I am doing wrong here? This problem was previously approached using an MLP, hence I am hoping KANs can provide a better solution. My code is the following:


import time
import torch
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

from efficient_kan.kan import KAN

# Pick the torch device everything below runs on ("cpu" by default).
device = torch.device(hardware := "cpu")

class FilterNet(torch.nn.Module):
    """Predict biquad filter coefficients from a target frequency response.

    A KAN maps an input frequency-response vector to 5 numbers
    (gain, pole_re, pole_im, zero_re, zero_im); these are converted to
    biquad coefficients (b, a) and turned back into a stacked
    (magnitude dB, phase deg) frequency response.
    """

    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquds=1, num_layers=1, fs=44100):
        # NOTE(review): num_batches / num_biquds (typo of "biquads") /
        # num_layers are unused; kept so existing callers keep working.
        super(FilterNet, self).__init__()
        self.eps = 1e-8  # numerical guard against division by zero
        self.fs = fs     # sampling rate in Hz
        # Unit impulse used to probe the filter's frequency response.
        # The probe signal itself does not need gradients.
        self.dirac = self.get_dirac(fs, 0, grad=False)
        self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)

    def get_dirac(self, size, index=1, grad=False):
        """Return a length-`size` tensor that is 1 at `index` and 0 elsewhere."""
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        """Filter `dirac` with coefficients (a, b) and return
        (freqs_hz, magnitude response in dB, phase response in degrees)
        for the positive half of the spectrum.
        """
        # torchaudio.functional.lfilter takes (waveform, a_coeffs, b_coeffs).
        filtered_dirac = lfilter(dirac, a, b)
        freqs_response = torch.fft.fft(filtered_dirac)

        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])

        # keep only the positive freqs
        # NOTE(review): rfftfreq returns cycles/sample, so `* fs` alone gives
        # Hz; the extra `/ np.pi` factor looks suspicious — verify the units.
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]

        # magnitude response in dB
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))

        # phase response, radians -> degrees
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def zpk2ba(self, zpk):
        """Convert a 5-vector (gain, pole_re, pole_im, zero_re, zero_im)
        into differentiable biquad coefficient tensors (b, a), a[0] == 1.

        Pole and zero are squashed inside the unit circle with tanh so the
        resulting filter is stable.
        """
        gain    = zpk[0]
        p0_real = zpk[1]
        p0_imag = zpk[2]
        q0_real = zpk[3]
        q0_imag = zpk[4]

        zero = torch.complex(q0_real, q0_imag)
        zero_abs = zero.abs()
        zero = ((1 - self.eps) * zero * torch.tanh(zero_abs)) / (zero_abs + self.eps)

        pole = torch.complex(p0_real, p0_imag)
        pole_abs = pole.abs()
        pole = ((1 - self.eps) * pole * torch.tanh(pole_abs)) / (pole_abs + self.eps)

        b0 = gain
        b1 = gain * -2 * zero.real
        b2 = gain * ((zero.real ** 2) + (zero.imag ** 2))
        a1 = -2 * pole.real
        a2 = (pole.real ** 2) + (pole.imag ** 2)
        # BUG FIX: the original rebuilt the coefficients with
        # torch.tensor([...], requires_grad=True), which creates brand-new
        # leaf tensors DETACHED from the autograd graph — no gradient ever
        # reached the KAN, so the loss stayed constant. torch.stack keeps
        # the graph connected from zpk through to (b, a).
        b = torch.stack([b0, b1, b2])
        a = torch.stack([torch.ones_like(a1), a1, a2])
        return b, a

    def forward(self, x):
        zpk = self.kan(x)

        # extract filter coeffs
        # BUG FIX: zpk2ba returns (b, a); the original unpacked it as
        # (self.a, self.b), silently swapping numerator and denominator
        # coefficients before filtering.
        self.b, self.a = self.zpk2ba(zpk)

        # get filter response 
        freqs_hz, mag_response_db, phase_response_deg = self.compute_filter_magnitude_and_phase_frequency_response(self.dirac, self.fs, self.a, self.b)
        frequency_response = torch.hstack((mag_response_db, phase_response_deg))
        return frequency_response

# Define the target filter variables
fs = 512               # nbr of input points
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# Init the optimizer 
n_epochs  = 500
# FilterNet(input_size, hidden_size, output_size, ...): the input is the
# stacked magnitude+phase response (fs points total), output is 5 zpk params.
model     = FilterNet(fs, fs*8, 5, 1, 1, 1, fs)
# NOTE(review): lr=1e-1 is very aggressive for Adam; 1e-3 is the usual
# starting point and worth trying if training diverges or stalls.
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# define filter coeffs (TARGET)
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

# compute filter response
# Build the training target: the (magnitude dB, phase deg) response of the
# known reference filter, probed with a unit impulse.
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device

    target = target_frequency_response.to(device)
    optimizer.zero_grad()

    # Compute prediction and loss
    predicted_frequency_response = model(target)
    loss = torch.nn.functional.mse_loss(predicted_frequency_response, target_frequency_response)

    # Backpropagation
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

    # Update the progress bar
    pbar.set_description(f"Epoch: {epoch}, Loss: {loss:.9f}")
    pbar.update(1)
    scheduler.step(loss)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
# Recompute the response from the last learned coefficients (detached, on CPU).
# NOTE(review): the helper's signature is (dirac, fs, a, b) — confirm that
# model.a / model.b really hold the denominator / numerator in that order,
# given how forward() unpacks the result of zpk2ba.
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, model.a.detach().cpu(), model.b.detach().cpu())
mw66 commented 5 days ago

@SuperKogito Some questions:

1) are you able to get good training results from MLP ? What's your training graph look like?

2) have you tried the official KAN implementation? And what is your training result?

SuperKogito commented 5 days ago

This seems to be unstable and not related to efficient_kan, as I am not able to get good training results using an MLP either :/ No, I have not tried the official KAN implementation.