KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
13.87k stars 1.25k forks source link

Curve fit using KAN resulting in constant loss and no improvement #278

Closed SuperKogito closed 2 weeks ago

SuperKogito commented 1 month ago

I am trying to predict certain function coefficients (output: a, b) based on its curve (input: frequency_response) with the help of https://github.com/Blealtan/efficient-kan (Kolmogorov-Arnold Network).

enter image description here

Unfortunately my loss is constant and is not improving at all. Any idea what am I doing wrong here? This problem was previously approached using MLP, hence I am hoping KANs can provide a better solution. My code is the following:


import time
import torch
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)

class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquds=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it 
        filtered_dirac = lfilter(dirac, a, b) 
        freqs_response = torch.fft.fft(filtered_dirac)

        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])

        # keep only the positive freqs
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]

        # magnitude response 
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))

        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def zpk2ba(self, zpk):
        gain    = zpk[0]
        p0_real = zpk[1]
        p0_imag = zpk[2]
        q0_real = zpk[3]
        q0_imag = zpk[4]

        zero = torch.complex(q0_real, q0_imag)
        zero_abs = zero.abs()
        zero = ((1 - self.eps) * zero * torch.tanh(zero_abs)) / (zero_abs + self.eps)

        pole = torch.complex(p0_real, p0_imag)
        pole_abs = pole.abs()
        pole = ((1 - self.eps) * pole * torch.tanh(pole_abs)) / (pole_abs + self.eps)

        b0 = gain 
        b1 = gain * -2 * zero.real
        b2 = gain * ((zero.real ** 2) + (zero.imag ** 2))
        a0 = 1
        a1 = -2 * pole.real
        a2 = (pole.real ** 2) + (pole.imag ** 2)
        b = torch.tensor([b0, b1, b2], requires_grad=True)
        a = torch.tensor([a0, a1, a2], requires_grad=True)
        return b, a

    def forward(self, x):
        zpk = self.kan(x)
        #print("> Zpk: ", zpk)

        # extract filter coeffs
        self.a, self.b = self.zpk2ba(zpk)

        # get filter reponse 
        freqs_hz, mag_response_db, phase_response_deg = self.compute_filter_magnitude_and_phase_frequency_response(self.dirac, self.fs, self.a, self.b)
        frequency_response = torch.hstack((mag_response_db, phase_response_deg))
        return frequency_response

# Define the target filter variables
fs = 512               # Sampling frequency
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# Init the optimizer 
n_epochs  = 500
model     = FilterNet(fs, fs*8, 5, 1, 1, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# define filter coeffs
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device

    target = target_frequency_response.to(device)
    optimizer.zero_grad()

    # Compute prediction and loss
    predicted_frequency_response = model(target)
    loss = torch.nn.functional.mse_loss(predicted_frequency_response, target_frequency_response)

    # Backpropagation
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

    # Update the progress bar
    pbar.set_description(f"Epoch: {epoch}, Loss: {loss:.9f}")
    pbar.update(1)
    scheduler.step(loss)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, model.a.detach().cpu(), model.b.detach().cpu())
KindXiaoming commented 1 month ago

Hi, this line is a bit suspicious predicted_frequency_response = model(target) do you really want to pass target as the input?

SuperKogito commented 1 month ago

I think so, in a way I am not exactly trying to train but rather to fit the parameters for the target. However, I am really not sure about this. The specific curve I am trying to predict here is: eq which is the same as: Screenshot from 2024-06-23 01-17-19 so my idea is essentially to use KAN network to do some guided sampling of the complex space, compute the a, b coefficients then the prediction then back propagate the resulting loss for the next iteration.

However, the system is stuck in the same loss and there no learning or update happening. I will appreciate any suggestions here, as I am not sure how to proceed at this point. thank you :)

Commit2Cosmos commented 4 weeks ago

Hi, this is not a solution, rather an attempt to provide some further insight into this issue. The gradients after loss.backward() are None for the all of the KAN's layers at every epoch, including the first one. This does not happen if zpk is returned straight away (with FilterNet's output_size set to fs):

def forward(self, x):
        zpk = self.kan(x)
        return zpk

Instead, after a few epochs, the last layers' gradients go to zero, and the rest are not unusual. The loss in this case settles on 253.68, which makes me think there's something about the zpk2ba or compute_filter_magnitude_and_phase_frequency_response functions and their integration with the KAN that does not allow the gradients to be computed properly. Any thoughts?

SuperKogito commented 4 weeks ago

@KindXiaoming and @Commit2Cosmos, Thank you both for taking the time to answer me. Unfortunately, I tried these different approaches but it still showed no difference. Here is some new code, where I implement the following:

  1. I generate some data (input and output), that I train the network with, instead of feeding in the target function. This is based on @KindXiaoming answer above.
  2. I also simplified the forward function to return directly the coefficients without using any post processing, that might ruin the gradient computations.
import time
import torch
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)

class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it 
        filtered_dirac = lfilter(dirac, a, b) 
        freqs_response = torch.fft.fft(filtered_dirac)

        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])

        # keep only the positive freqs
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]

        # magnitude response 
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))

        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def forward(self, x):
        ba = self.kan(x)
        return ba

# Define the target filter variables
fs = 2048 # 44100             # Sampling frequency
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# define filter coeffs
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

# prepare data
import scipy.signal as signal 
f0 = 20
f1 = 20e3
t = np.linspace(0, 60, fs, dtype=np.float32)
sine_sweep   = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic')
white_noise  = np.random.normal(scale=5e-2, size=len(t)) 
noisy_sweep  = sine_sweep + white_noise
train_input  = torch.from_numpy(noisy_sweep.astype(np.float32))
train_target = lfilter(train_input, a, b) 

# Init the optimizer 
n_epochs    = 3
batche_size = 1
seq_length  = 512
model     = FilterNet(seq_length, 8*seq_length, 6, batche_size, num_biquads, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = torch.nn.MSELoss()

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# data batching 
num_sequences = int(train_input.shape[0] / seq_length)

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device
    print("Epoch : ", epoch)
    total_loss = 0
    for seq_id in range(num_sequences):
        start_idx = seq_id*256
        end_idx   = seq_id*256 + seq_length 
        # print(seq_id, start_idx, end_idx)

        input_seq_batch  = train_input[start_idx:end_idx].unsqueeze(0).to(device)
        target_seq_batch = train_target[start_idx:end_idx].unsqueeze(0).to(device)        
        optimizer.zero_grad()

        # Compute prediction and loss
        ba = model(input_seq_batch)       
        # print("ba: ", ba)
        # model.zpk2ba(predicted_output)
        filtered_x = lfilter(input_seq_batch, ba[:, 3:], ba[:, :3])
        batch_loss = torch.nn.functional.mse_loss(filtered_x, target_seq_batch)

        # Backpropagation
        #print("~"*25)
        #print(batch_loss.grad)
        batch_loss.backward()
        #print(batch_loss.grad)
        #print("~"*25)
        optimizer.step()
        total_loss += batch_loss.item()
        print(seq_id, ":", batch_loss.item())

    # record loss
    epoch_loss = total_loss / num_sequences
    loss_history.append(epoch_loss)
    print("-"* 100)
    print(epoch_loss, "|", total_loss)

    # Update the progress bar
    #pbar.set_description(f"\nEpoch: {epoch}, Loss: {epoch_loss:.9f}\n")
    #pbar.update(1)
    scheduler.step(total_loss)
    print("*"* 100)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
predicted_a = ba[:, 3:].detach().cpu()
predicted_b = ba[:, :3].detach().cpu()
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, predicted_a, predicted_b)

Unfortunately, this results in a nan loss and gradient 😢

KindXiaoming commented 4 weeks ago

can't eyeball the problem real quick. Could you please try replace kan with mlp and see if it works? (just want to make sure if the bug is in kan or in other parts?)

SuperKogito commented 3 weeks ago

Using an mlp instead of kan results in nan values for the loss and the gradient too.

SuperKogito commented 3 weeks ago

Here is a more clean version with better debugging and using mlp :

import time
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)

class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.mlp = torch.nn.Sequential(torch.nn.Linear(input_size, 100),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(100, 50),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(50, output_size))
        self.sos = torch.rand(num_biquads, 6, device=hardware, dtype=torch.float32, requires_grad=True)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it 
        filtered_dirac = lfilter(dirac, a, b) 
        freqs_response = torch.fft.fft(filtered_dirac)

        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])

        # keep only the positive freqs
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]

        # magnitude response 
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))

        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def forward(self, x):
        self.sos = self.mlp(x)
        return self.sos

# Define the target filter variables
fs = 2048 # 44100             # Sampling frequency
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# define filter coeffs
target_sos = torch.tensor([0.803, -0.132, 0.731, 1.000, -0.426, 0.850])
a = target_sos[3:]
b = target_sos[:3]

# prepare data
import scipy.signal as signal 
f0 = 20
f1 = 20e3
t = np.linspace(0, 60, fs, dtype=np.float32)
sine_sweep   = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic')
white_noise  = np.random.normal(scale=5e-2, size=len(t)) 
noisy_sweep  = sine_sweep + white_noise
train_input  = torch.from_numpy(noisy_sweep.astype(np.float32))
train_target = lfilter(train_input, a, b) 

# Init the optimizer 
n_epochs    = 9
batche_size = 1
seq_length  = 512
seq_step    = 512
model     = FilterNet(seq_length, 10*seq_length, 6, batche_size, num_biquads, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = torch.nn.MSELoss()

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# data batching 
num_sequences = int(train_input.shape[0] / seq_length)

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device
    print("\n+ Epoch : ", epoch)
    total_loss = 0
    for seq_id in range(num_sequences):
        start_idx = seq_id*seq_step
        end_idx   = seq_id*seq_step + seq_length 
        # print(seq_id, start_idx, end_idx)

        input_seq_batch  = train_input[start_idx:end_idx].unsqueeze(0).to(device)
        target_seq_batch = train_target[start_idx:end_idx].unsqueeze(0).to(device)        
        optimizer.zero_grad()

        # Compute prediction and loss
        sos = model(input_seq_batch)
        y = lfilter(waveform=input_seq_batch, b_coeffs=sos[:, :3], a_coeffs=sos[:, 3:])
        batch_loss = torch.nn.functional.mse_loss(y, target_seq_batch)

        sos.requires_grad_(True)
        y.requires_grad_(True)
        batch_loss.requires_grad_(True)

        print("|-> y                            : ", y.grad)
        print("|-> sos                          : ", sos.grad)
        print("|-> batch_loss (before backprop) : ", batch_loss.grad)

        # Backpropagation

        batch_loss.backward()
        print("|-> batch_loss (after backprop)  : ", batch_loss.grad)

        optimizer.step()
        total_loss += batch_loss.item()
        print(f"|=========> Sequence {seq_id}: Loss = {batch_loss.item():.9f}")

    # record loss
    epoch_loss = total_loss / num_sequences
    loss_history.append(epoch_loss)
    print("-"* 100)
    print(f"|=========> epoch_loss = {epoch_loss:.3f} | Loss = {epoch_loss:.3f}")

    # Update the progress bar
    #pbar.set_description(f"\nEpoch: {epoch}, Loss: {epoch_loss:.9f}\n")
    #pbar.update(1)
    scheduler.step(total_loss)
    print("*"* 100)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
predicted_a = model.sos[:, 3:].detach().cpu().T.squeeze(1)
predicted_b = model.sos[:, :3].detach().cpu().T.squeeze(1)
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, predicted_a, predicted_b)

somehow all the gradients are None and the loss get's quickly to nan. When I remove the lfilter and use batch_loss = torch.nn.functional.mse_loss(y, target_seq_batch) for the loss, the algorithm converges.

Output

+ Epoch :  0
|-> y                            :  None
|-> sos                          :  None
|-> batch_loss (before backprop) :  None
|-> batch_loss (after backprop)  :  None
|=========> Sequence 0: Loss = 1.100062847
|-> y                            :  None
|-> sos                          :  None
|-> batch_loss (before backprop) :  None
|-> batch_loss (after backprop)  :  None
|=========> Sequence 1: Loss = 1.415661454
|-> y                            :  None
|-> sos                          :  None
|-> batch_loss (before backprop) :  None
|-> batch_loss (after backprop)  :  None
|=========> Sequence 2: Loss = nan
|-> y                            :  None
|-> sos                          :  None
|-> batch_loss (before backprop) :  None
|-> batch_loss (after backprop)  :  None
|=========> Sequence 3: Loss = nan
----------------------------------------------------------------------------------------------------
|=========> epoch_loss = nan | Loss = nan
****************************************************************************************************
.
.
.
KindXiaoming commented 3 weeks ago

hmm, looks like the loss is not nan at first but only become nan later (it's a bit subtle since you are using batch). It's likely that the loss is nan for some samples but not all. Could you please compute per-sample loss (not the averaged loss), and see which samples have nan loss?

SuperKogito commented 2 weeks ago

With a lower learning rate, shorter sequences, added stability constraint and bit more fine tuning, I managed to have a stable run and a more smoothly decaying loss.