Closed: SuperKogito closed this issue 2 weeks ago.
Hi, this line is a bit suspicious:

```python
predicted_frequency_response = model(target)
```

Do you really want to pass `target` as the input?
I think so; in a way, I am not exactly trying to train but rather to fit the parameters to the target. However, I am really not sure about this. The specific curve I am trying to predict here is:
which is the same as:
So my idea is essentially to use a KAN network to do some guided sampling of the complex space, compute the a, b coefficients, then the prediction, and then backpropagate the resulting loss for the next iteration.
However, the system is stuck at the same loss and no learning or update is happening. I would appreciate any suggestions here, as I am not sure how to proceed at this point. Thank you :)
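The pipeline described above (network → a, b → frequency response → loss) only trains if the response computation itself is differentiable. As a sketch (not code from this thread; `biquad_response_db` is a hypothetical helper, and a single biquad with `b = [b0, b1, b2]`, `a = [a0, a1, a2]` is assumed), one option is to evaluate the transfer function directly on the unit circle instead of filtering a Dirac with `lfilter` and taking an FFT:

```python
import math
import torch


def biquad_response_db(b: torch.Tensor, a: torch.Tensor, n_freqs: int = 512, eps: float = 1e-12):
    """Magnitude (dB) and phase (rad) of H(z) = B(z) / A(z) at n_freqs points of [0, pi).

    b, a: tensors of shape (..., 3) holding [b0, b1, b2] and [a0, a1, a2].
    Only real-valued ops are used, so gradients flow back to b and a.
    """
    w = torch.arange(n_freqs, dtype=b.dtype) * (math.pi / n_freqs)  # rad/sample
    k = torch.arange(3, dtype=b.dtype)
    wk = w[:, None] * k  # (n_freqs, 3): z^{-k} on the unit circle is e^{-j w k}
    cos_wk, sin_wk = torch.cos(wk), torch.sin(wk)

    def poly(c):
        # sum_k c_k e^{-j w k} = sum_k c_k cos(w k) - j sum_k c_k sin(w k)
        c = c[..., None, :]
        return (c * cos_wk).sum(-1), -(c * sin_wk).sum(-1)

    b_re, b_im = poly(b)
    a_re, a_im = poly(a)
    # |H|^2 = |B|^2 / |A|^2; eps keeps log10 finite if B or A vanishes
    mag_db = 10.0 * torch.log10((b_re**2 + b_im**2 + eps) / (a_re**2 + a_im**2 + eps))
    phase = torch.atan2(b_im, b_re) - torch.atan2(a_im, a_re)
    phase = torch.atan2(torch.sin(phase), torch.cos(phase))  # wrap to (-pi, pi]
    return w, mag_db, phase
```

The returned `mag_db` (and `phase`) could then be compared with the target curve using an MSE loss. Note that `a0` still has to stay away from zero, since the response is divided by `A`.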
Hi, this is not a solution, rather an attempt to provide some further insight into this issue. The gradients after `loss.backward()` are `None` for all of the KAN's layers at every epoch, including the first one. This does not happen if `zpk` is returned straight away (with FilterNet's `output_size` set to `fs`):
```python
def forward(self, x):
    zpk = self.kan(x)
    return zpk
```
Instead, after a few epochs, the last layers' gradients go to zero, and the rest look normal. The loss in this case settles at 253.68, which makes me think there is something about the `zpk2ba` or `compute_filter_magnitude_and_phase_frequency_response` functions and their integration with the KAN that prevents the gradients from being computed properly. Any thoughts?
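The thread never shows `zpk2ba`. A frequent cause of `None` gradients in this kind of pipeline is a step that leaves autograd, such as a round trip through NumPy/SciPy or rebuilding a tensor with `torch.tensor(...)` from intermediate values; whether that applies here cannot be told from the thread. For comparison, here is a hypothetical `zpk2ba_biquad` written only with real-valued tensor ops, assuming one biquad with a complex-conjugate zero pair and pole pair, parameterized by the real and imaginary parts of one member of each pair plus a gain:

```python
import torch


def zpk2ba_biquad(zero_re, zero_im, pole_re, pole_im, gain):
    """Hypothetical zpk -> (b, a) for one biquad with conjugate zero and pole pairs.

    Uses (1 - q z^{-1})(1 - conj(q) z^{-1}) = 1 - 2 Re(q) z^{-1} + |q|^2 z^{-2}.
    All arguments are real tensors of the same shape; outputs have shape (..., 3).
    """
    one = torch.ones_like(pole_re)
    # numerator: gain * [1, -2 Re(zero), |zero|^2]
    b = gain[..., None] * torch.stack([one, -2.0 * zero_re, zero_re**2 + zero_im**2], dim=-1)
    # denominator (a0 fixed to 1): [1, -2 Re(pole), |pole|^2]
    a = torch.stack([one, -2.0 * pole_re, pole_re**2 + pole_im**2], dim=-1)
    return b, a
```

Because every step is a tensor operation, gradients reach whichever network produced the zero, pole, and gain parameters.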
@KindXiaoming and @Commit2Cosmos, thank you both for taking the time to answer me. Unfortunately, I tried these different approaches, but they made no difference. Here is some new code, where I implement the following:
```python
import time
import torch
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler
from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)


class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it
        filtered_dirac = lfilter(dirac, a, b)
        freqs_response = torch.fft.fft(filtered_dirac)
        n = filtered_dirac.shape[-1]
        # frequency axis in Hz: rfftfreq(n, d=1/fs) spans 0 ... fs/2
        # (rfftfreq(n) alone is in cycles/sample, so scaling it by fs / pi was off by a factor of pi)
        # keep only the positive freqs
        freqs_hz = torch.fft.rfftfreq(n, d=1.0 / fs)[: n // 2]
        freqs_response = freqs_response[:len(freqs_hz)]
        # magnitude response (eps avoids log10(0) = -inf)
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response) + self.eps)
        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def forward(self, x):
        ba = self.kan(x)
        return ba


# Define the target filter variables
fs = 2048  # 44100 # Sampling frequency
num_biquads = 1  # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# define filter coeffs
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

# prepare data
import scipy.signal as signal

f0 = 20
f1 = 20e3
t = np.linspace(0, 60, fs, dtype=np.float32)
sine_sweep = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic')
white_noise = np.random.normal(scale=5e-2, size=len(t))
noisy_sweep = sine_sweep + white_noise
train_input = torch.from_numpy(noisy_sweep.astype(np.float32))
train_target = lfilter(train_input, a, b)

# Init the optimizer
n_epochs = 3
batche_size = 1
seq_length = 512
model = FilterNet(seq_length, 8*seq_length, 6, batche_size, num_biquads, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = torch.nn.MSELoss()

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()  # Start timing the loop
pbar = tqdm(total=n_epochs)  # Create a tqdm progress bar
loss_history = []

# data batching
num_sequences = int(train_input.shape[0] / seq_length)

# Run training
for epoch in range(n_epochs):
    model.train()
    device = next(model.parameters()).device
    print("Epoch : ", epoch)
    total_loss = 0
    for seq_id in range(num_sequences):
        start_idx = seq_id*256
        end_idx = seq_id*256 + seq_length
        # print(seq_id, start_idx, end_idx)
        input_seq_batch = train_input[start_idx:end_idx].unsqueeze(0).to(device)
        target_seq_batch = train_target[start_idx:end_idx].unsqueeze(0).to(device)
        optimizer.zero_grad()

        # Compute prediction and loss
        ba = model(input_seq_batch)
        # print("ba: ", ba)
        # model.zpk2ba(predicted_output)
        filtered_x = lfilter(input_seq_batch, ba[:, 3:], ba[:, :3])
        batch_loss = torch.nn.functional.mse_loss(filtered_x, target_seq_batch)

        # Backpropagation
        #print("~"*25)
        #print(batch_loss.grad)
        batch_loss.backward()
        #print(batch_loss.grad)
        #print("~"*25)
        optimizer.step()
        total_loss += batch_loss.item()
        print(seq_id, ":", batch_loss.item())

    # record loss
    epoch_loss = total_loss / num_sequences
    loss_history.append(epoch_loss)
    print("-"* 100)
    print(epoch_loss, "|", total_loss)

    # Update the progress bar
    #pbar.set_description(f"\nEpoch: {epoch}, Loss: {epoch_loss:.9f}\n")
    #pbar.update(1)
    scheduler.step(total_loss)
    print("*"* 100)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
predicted_a = ba[:, 3:].detach().cpu()
predicted_b = ba[:, :3].detach().cpu()
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, predicted_a, predicted_b)
```
Unfortunately, this results in a `nan` loss and gradient 😢
I can't eyeball the problem quickly. Could you please try replacing the KAN with an MLP and see if it works? (I just want to check whether the bug is in the KAN or in other parts.)
Using an MLP instead of the KAN also results in `nan` values for the loss and the gradient.
Here is a cleaner version with better debugging, using an MLP:
```python
import time
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler
from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)


class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.mlp = torch.nn.Sequential(torch.nn.Linear(input_size, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(100, 50),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(50, output_size))
        self.sos = torch.rand(num_biquads, 6, device=hardware, dtype=torch.float32, requires_grad=True)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it
        filtered_dirac = lfilter(dirac, a, b)
        freqs_response = torch.fft.fft(filtered_dirac)
        n = filtered_dirac.shape[-1]
        # frequency axis in Hz: rfftfreq(n, d=1/fs) spans 0 ... fs/2
        # (rfftfreq(n) alone is in cycles/sample, so scaling it by fs / pi was off by a factor of pi)
        # keep only the positive freqs
        freqs_hz = torch.fft.rfftfreq(n, d=1.0 / fs)[: n // 2]
        freqs_response = freqs_response[:len(freqs_hz)]
        # magnitude response (eps avoids log10(0) = -inf)
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response) + self.eps)
        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def forward(self, x):
        self.sos = self.mlp(x)
        return self.sos


# Define the target filter variables
fs = 2048  # 44100 # Sampling frequency
num_biquads = 1  # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# define filter coeffs
target_sos = torch.tensor([0.803, -0.132, 0.731, 1.000, -0.426, 0.850])
a = target_sos[3:]
b = target_sos[:3]

# prepare data
import scipy.signal as signal

f0 = 20
f1 = 20e3
t = np.linspace(0, 60, fs, dtype=np.float32)
sine_sweep = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic')
white_noise = np.random.normal(scale=5e-2, size=len(t))
noisy_sweep = sine_sweep + white_noise
train_input = torch.from_numpy(noisy_sweep.astype(np.float32))
train_target = lfilter(train_input, a, b)

# Init the optimizer
n_epochs = 9
batche_size = 1
seq_length = 512
seq_step = 512
model = FilterNet(seq_length, 10*seq_length, 6, batche_size, num_biquads, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = torch.nn.MSELoss()

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()  # Start timing the loop
pbar = tqdm(total=n_epochs)  # Create a tqdm progress bar
loss_history = []

# data batching
num_sequences = int(train_input.shape[0] / seq_length)

# Run training
for epoch in range(n_epochs):
    model.train()
    device = next(model.parameters()).device
    print("\n+ Epoch : ", epoch)
    total_loss = 0
    for seq_id in range(num_sequences):
        start_idx = seq_id*seq_step
        end_idx = seq_id*seq_step + seq_length
        # print(seq_id, start_idx, end_idx)
        input_seq_batch = train_input[start_idx:end_idx].unsqueeze(0).to(device)
        target_seq_batch = train_target[start_idx:end_idx].unsqueeze(0).to(device)
        optimizer.zero_grad()

        # Compute prediction and loss
        sos = model(input_seq_batch)
        y = lfilter(waveform=input_seq_batch, b_coeffs=sos[:, :3], a_coeffs=sos[:, 3:])
        batch_loss = torch.nn.functional.mse_loss(y, target_seq_batch)
        sos.requires_grad_(True)
        y.requires_grad_(True)
        batch_loss.requires_grad_(True)
        print("|-> y : ", y.grad)
        print("|-> sos : ", sos.grad)
        print("|-> batch_loss (before backprop) : ", batch_loss.grad)

        # Backpropagation
        batch_loss.backward()
        print("|-> batch_loss (after backprop) : ", batch_loss.grad)
        optimizer.step()
        total_loss += batch_loss.item()
        print(f"|=========> Sequence {seq_id}: Loss = {batch_loss.item():.9f}")

    # record loss
    epoch_loss = total_loss / num_sequences
    loss_history.append(epoch_loss)
    print("-"* 100)
    print(f"|=========> epoch_loss = {epoch_loss:.3f} | Loss = {epoch_loss:.3f}")

    # Update the progress bar
    #pbar.set_description(f"\nEpoch: {epoch}, Loss: {epoch_loss:.9f}\n")
    #pbar.update(1)
    scheduler.step(total_loss)
    print("*"* 100)

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
predicted_a = model.sos[:, 3:].detach().cpu().T.squeeze(1)
predicted_b = model.sos[:, :3].detach().cpu().T.squeeze(1)
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, predicted_a, predicted_b)
```
Somehow all the gradients are `None`, and the loss quickly becomes `nan`. When I remove the `lfilter` and use `batch_loss = torch.nn.functional.mse_loss(y, target_seq_batch)` for the loss, the algorithm converges.
```
+ Epoch : 0
|-> y : None
|-> sos : None
|-> batch_loss (before backprop) : None
|-> batch_loss (after backprop) : None
|=========> Sequence 0: Loss = 1.100062847
|-> y : None
|-> sos : None
|-> batch_loss (before backprop) : None
|-> batch_loss (after backprop) : None
|=========> Sequence 1: Loss = 1.415661454
|-> y : None
|-> sos : None
|-> batch_loss (before backprop) : None
|-> batch_loss (after backprop) : None
|=========> Sequence 2: Loss = nan
|-> y : None
|-> sos : None
|-> batch_loss (before backprop) : None
|-> batch_loss (after backprop) : None
|=========> Sequence 3: Loss = nan
----------------------------------------------------------------------------------------------------
|=========> epoch_loss = nan | Loss = nan
****************************************************************************************************
...
```
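In the MLP script, `y`, `sos` and `batch_loss` are outputs of operations (non-leaf tensors), so PyTorch does not populate their `.grad` after `backward()` unless `retain_grad()` was called on them beforehand; calling `requires_grad_(True)` after the forward pass does not change that. The printed `None` values therefore do not by themselves show that the graph is broken. A more informative check is the `.grad` of the model's parameters, which are leaf tensors. A minimal sketch, using a stand-in `torch.nn.Linear` in place of `FilterNet`:

```python
import torch

# Stand-in model; the same leaf/non-leaf rules apply to FilterNet.
model = torch.nn.Linear(4, 6)
x = torch.randn(1, 4)

sos = model(x)            # non-leaf: produced by an operation
sos.retain_grad()         # keep sos.grad; must be called before backward()
loss = sos.pow(2).mean()  # also non-leaf, and retain_grad() is not called on it
loss.backward()

print(sos.grad is not None)  # True, thanks to retain_grad()
print(loss.grad)             # None (PyTorch may also emit a UserWarning here)
for name, p in model.named_parameters():
    # Parameters are leaves: this is where gradients actually accumulate.
    print(name, p.grad is not None, bool(torch.isfinite(p.grad).all()))
```

Checking `torch.isfinite` on the parameter gradients after each step also shows when `nan` first appears.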
Hmm, it looks like the loss is not `nan` at first but only becomes `nan` later (it's a bit subtle since you are using batches). It's likely that the loss is `nan` for some samples but not all. Could you please compute the per-sample loss (not the averaged loss) and see which samples have a `nan` loss?
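A minimal sketch of that per-sample check, assuming `y` and `target_seq_batch` have shape `(batch, time)` as in the MLP script; `per_sample_loss` and `nonfinite_positions` are hypothetical helper names. `reduction="none"` keeps the element-wise losses, and `torch.isfinite` locates non-finite outputs. In a recursive (IIR) filter a non-finite value is fed back into later outputs, so the first flagged time index is usually the informative one.

```python
import torch
import torch.nn.functional as F


def per_sample_loss(y: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE per sequence, shape (batch,): averages only over the time axis."""
    return F.mse_loss(y, target, reduction="none").mean(dim=-1)


def nonfinite_positions(y: torch.Tensor) -> torch.Tensor:
    """(batch, time) indices of samples that are nan or +/-inf."""
    return torch.nonzero(~torch.isfinite(y))


# Toy example: the first sequence contains a nan at time index 2.
y = torch.tensor([[0.0, 1.0, float("nan")], [1.0, 1.0, 1.0]])
target = torch.zeros_like(y)
losses = per_sample_loss(y, target)
bad = nonfinite_positions(y)
print(losses)          # first entry is nan, second is 1.0
print(bad.tolist())    # [[0, 2]]
```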
With a lower learning rate, shorter sequences, an added stability constraint, and a bit more fine-tuning, I managed to get a stable run and a more smoothly decaying loss.
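The thread does not show which stability constraint was used. One common option (an assumption here, not necessarily what was done) is to let the network output unconstrained values and map them into the biquad stability triangle: for A(z) = 1 + a1 z^-1 + a2 z^-2, both poles lie strictly inside the unit circle when |a2| < 1 and |a1| < 1 + a2. Fixing a0 = 1 also avoids dividing by a learned, possibly tiny, a0. `stable_biquad_coeffs` and its `margin` parameter are hypothetical:

```python
import torch


def stable_biquad_coeffs(raw: torch.Tensor, margin: float = 1e-3):
    """Map unconstrained outputs raw[..., 0:5] to a stable, a0-normalized biquad.

    raw[..., 0:3] -> b = [b0, b1, b2] (left unconstrained)
    raw[..., 3:5] -> (a1, a2) inside the stability triangle |a2| < 1, |a1| < 1 + a2.
    margin keeps the poles away from the unit circle even when tanh saturates to +/-1.
    """
    b = raw[..., 0:3]
    a2 = (1.0 - margin) * torch.tanh(raw[..., 4])
    a1 = (1.0 - margin) * (1.0 + a2) * torch.tanh(raw[..., 3])
    a0 = torch.ones_like(a1)  # a0 fixed to 1
    a = torch.stack([a0, a1, a2], dim=-1)
    return b, a
```

With a network output size of 5, the returned `b` and `a` (each of shape `(batch, 3)`) could be passed to `lfilter` as `b_coeffs` and `a_coeffs`.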
I am trying to predict certain function coefficients (output: a, b) from its curve (input: frequency_response) with the help of https://github.com/Blealtan/efficient-kan (Kolmogorov-Arnold Network).
Unfortunately, my loss is constant and not improving at all. Any idea what I am doing wrong here? This problem was previously approached using an MLP, hence I am hoping KANs can provide a better solution. My code is the following: