pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

`torchaudio.functional.lfilter` returns `nan` when processing sub-array but not for the whole input array. #3807

Closed SuperKogito closed 5 hours ago

SuperKogito commented 3 days ago

🐛 Describe the bug

When using torchaudio.functional.lfilter to generate training data, lfilter works as expected. However, when using it in the loss computation, it returns a bunch of nan values in the prediction, causing my loss to become nan.

Code

import time
import torch
import math
import torchaudio
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

# Set the device
hardware = "cpu"
device = torch.device(hardware)
debug = False

class FilterLoss(torch.nn.Module):
    def __init__(self, debug=True):
        super().__init__()
        self.debug=debug

    def forward(self, sos, y, target_y):
        predicted_y = lfilter(y, sos[:, :3], sos[:, 3:])
        loss = torch.nn.functional.mse_loss(predicted_y, target_y)
        if self.debug:
            print("> sos: ", sos)
            print("> y: ", y)
            print("> predicted_y: ", predicted_y)
            print("> loss: ", loss)        
        return loss

class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.mlp = torch.nn.Sequential(torch.nn.Linear(input_size, 100),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(100, 50),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(50, output_size))
        #self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)
        self.sos = torch.rand(num_biquads, 6, device=hardware, dtype=torch.float32, requires_grad=True)
        # self.zpk = torch.rand(num_biquads, 5, device=hardware, dtype=torch.float32, requires_grad=True)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it 
        filtered_dirac = lfilter(dirac, a, b) 
        freqs_response = torch.fft.fft(filtered_dirac)

        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])

        # keep only the positive freqs
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]

        # magnitude response 
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))

        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg

    def forward(self, x):
        self.sos = self.mlp(x)
        return self.sos

# Define the target filter variables
fs = 1024              # Sampling frequency
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# define filter coeffs
target_sos = torch.tensor([0.803, -0.132, 0.731, 1.000, -0.426, 0.850])
a = target_sos[3:]
b = target_sos[:3]

# prepare data
import scipy.signal as signal 
f0 = 20
f1 = 20e3
t = np.linspace(0, 60, fs, dtype=np.float32)
sine_sweep   = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic')
white_noise  = np.random.normal(scale=5e-2, size=len(t)) 
noisy_sweep  = sine_sweep + white_noise
noisy_sweep_normalized = noisy_sweep / np.max(np.abs(noisy_sweep))

train_input  = torch.from_numpy(noisy_sweep_normalized.astype(np.float32))
train_target = lfilter(train_input, a, b) 
print("Result has nan values? ", any(list(torch.isnan(train_target).detach().cpu().numpy())))

# Init the optimizer 
n_epochs    = 100
batch_size  = 1
seq_length  = 512
seq_step    = 256
model     = FilterNet(seq_length, 10*seq_length, 6, batch_size, num_biquads, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = FilterLoss()

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# data batching 
num_sequences = int(train_input.shape[0] / seq_length)

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device
    print("\n+ Epoch : ", epoch)
    total_loss = 0
    for seq_id in range(num_sequences):
        start_idx = seq_id*seq_step
        end_idx   = seq_id*seq_step + seq_length 
        # print(seq_id, start_idx, end_idx)

        input_seq_batch  = train_input[start_idx:end_idx].unsqueeze(0).to(device)
        target_seq_batch = train_target[start_idx:end_idx].unsqueeze(0).to(device)        
        optimizer.zero_grad()

        # Compute prediction and loss
        predicted_sos = model(input_seq_batch)
        batch_loss = criterion(predicted_sos, input_seq_batch, target_seq_batch)

        predicted_sos.requires_grad_(True)
        batch_loss.requires_grad_(True)

        if debug:
            print("|-> ∇ predicted_sos                : ", predicted_sos.grad)
            print("|-> ∇ batch_loss (before backprop) : ", batch_loss.grad)
        # Backpropagation

        batch_loss.backward()
        if debug:
            print("|-> ∇ batch_loss (after backprop)  : ", batch_loss.grad)

        optimizer.step()
        total_loss += batch_loss.item()
        if debug:
            print(f"|=========> Sequence {seq_id}: Loss = {batch_loss.item():.9f}")

    # record loss
    epoch_loss = total_loss / num_sequences
    loss_history.append(epoch_loss)
    print("-"* 100)
    print(f"|=========> epoch_loss = {epoch_loss:.3f} | Loss = {epoch_loss:.3f}")

    # Update the progress bar
    #pbar.set_description(f"\nEpoch: {epoch}, Loss: {epoch_loss:.9f}\n")
    #pbar.update(1)
    scheduler.step(total_loss)
    print("*"* 100)
    if math.isnan(epoch_loss):
        break

# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
predicted_a = model.sos[:, 3:].detach().cpu().T.squeeze(1)
predicted_b = model.sos[:, :3].detach().cpu().T.squeeze(1)
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, predicted_a, predicted_b)

Output

Usually it takes some iterations for the nan values to show up, but I am not sure why this happens. Could this be a floating point error, since I am using float and not double?

  0%|          | 0/100 [01:39<?, ?it/s]
Result has nan values?  False
  0%|          | 0/100 [00:00<?, ?it/s]
+ Epoch :  0
> sos:  tensor([[-0.0839,  0.0890,  0.0432,  0.0235,  0.0009,  0.0239]],
       grad_fn=<AddmmBackward0>)
> y:  tensor([[ 9.1976e-01,  3.9396e-01, -5.6361e-01, -9.0263e-01,  4.0705e-02,
          8.9212e-01,  3.7140e-01, -6.8597e-01, -5.4972e-01,  5.8698e-01,
          5.9120e-01, -5.8381e-01, -4.9494e-01,  7.6632e-01,  9.3482e-02,
         -8.1863e-01,  3.3197e-01,  5.9930e-01, -8.4578e-01,  1.7147e-01,
          7.3452e-01, -8.1758e-01,  3.3193e-01,  4.1554e-01, -8.9562e-01,
          8.3262e-01, -5.6126e-01, -5.4229e-03,  5.4372e-01, -8.6507e-01,
          9.2262e-01, -8.0890e-01,  7.3650e-01, -6.0981e-01,  3.9738e-01,
         -3.7726e-01,  3.1291e-01, -2.9066e-01,  3.5591e-01, -4.2745e-01,
          5.6797e-01, -7.0850e-01,  8.5734e-01, -7.9908e-01,  7.4616e-01,
         -3.5294e-01, -1.1078e-01,  6.8103e-01, -8.6427e-01,  6.2385e-01,
          1.0429e-01, -6.5078e-01,  7.7670e-01,  3.6204e-02, -7.6227e-01,
          5.9776e-01,  3.4664e-01, -8.7046e-01, -1.6191e-01,  8.9873e-01,
          1.4116e-01, -8.5808e-01, -3.7914e-01,  6.2105e-01,  7.6325e-01,
         -1.3730e-01, -8.1782e-01, -6.7562e-01,  9.2269e-02,  6.3184e-01,
          8.7427e-01,  5.4773e-01,  4.0460e-02, -3.6624e-01, -7.3464e-01,
         -8.2703e-01, -9.1126e-01, -8.5777e-01, -8.4543e-01, -8.6385e-01,
         -7.9900e-01, -8.1887e-01, -8.7094e-01, -7.7519e-01, -7.2416e-01,
         -4.4006e-01,  1.6800e-01,  5.6366e-01,  9.3301e-01,  7.5224e-01,
         -7.3612e-02, -7.7336e-01, -6.8303e-01,  2.4916e-01,  9.2302e-01,
          1.9588e-01, -8.5115e-01, -1.6773e-01,  9.0949e-01,  3.9266e-02,
         -8.1771e-01,  6.6815e-01,  3.2386e-01, -8.7820e-01,  7.4271e-01,
         -1.3039e-01, -4.7525e-01,  7.6554e-01, -9.1396e-01,  7.4807e-01,
         -6.7082e-01,  6.0845e-01, -6.1776e-01,  6.2066e-01, -7.3603e-01,
          7.8326e-01, -8.7389e-01,  8.3079e-01, -4.7536e-01, -9.8151e-02,
          7.6845e-01, -8.5406e-01,  1.8633e-01,  7.7706e-01, -7.8508e-01,
         -2.8982e-01,  8.7603e-01,  2.4262e-01, -7.6036e-01, -6.1301e-01,
          3.6537e-01,  8.2418e-01,  6.9165e-01, -1.0009e-01, -6.3840e-01,
         -8.3986e-01, -8.6909e-01, -7.6598e-01, -6.7613e-01, -6.1851e-01,
         -6.4054e-01, -7.3947e-01, -7.8320e-01, -9.1459e-01, -6.4904e-01,
         -4.5626e-02,  5.9469e-01,  9.0986e-01,  3.2173e-01, -6.2907e-01,
         -6.3078e-01,  5.6355e-01,  6.8357e-01, -7.2929e-01, -1.2930e-01,
          8.6535e-01, -6.8470e-01,  7.9903e-02,  5.0202e-01, -7.1384e-01,
          9.1437e-01, -8.7692e-01,  8.9491e-01, -8.4579e-01,  8.3886e-01,
         -6.4914e-01,  3.8716e-01,  2.9820e-01, -7.9972e-01,  8.0648e-01,
          1.4961e-01, -8.8522e-01,  2.0206e-01,  8.7826e-01,  7.5181e-02,
         -7.8768e-01, -7.8199e-01, -6.8105e-02,  4.8316e-01,  7.6755e-01,
          7.9499e-01,  8.5816e-01,  9.0855e-01,  8.7831e-01,  8.1642e-01,
          5.8383e-01,  8.7610e-02, -5.6985e-01, -9.0071e-01, -2.2753e-01,
          8.7734e-01,  2.6020e-01, -9.3027e-01,  6.3194e-02,  7.5747e-01,
         -8.9864e-01,  5.8946e-01, -1.1875e-01, -1.3843e-01,  1.8369e-01,
         -1.8120e-01, -5.8798e-02,  4.5497e-01, -8.4947e-01,  8.2299e-01,
         -9.0557e-02, -7.7665e-01,  4.3142e-01,  7.3419e-01, -2.1777e-01,
         -8.8053e-01, -7.3262e-01, -3.1814e-01,  1.2663e-01,  4.1220e-01,
          4.4359e-01,  2.4527e-01,  1.1442e-02, -5.4809e-01, -8.5944e-01,
         -5.1928e-01,  5.1216e-01,  7.7065e-01, -5.3413e-01, -5.1005e-01,
          9.1581e-01, -6.6810e-01,  2.8103e-01,  1.0429e-01, -1.2456e-01,
         -2.0051e-02,  3.2061e-01, -7.7232e-01,  8.9089e-01, -9.2229e-02,
         -7.7709e-01,  1.7892e-01,  8.9708e-01,  4.3318e-01, -2.9382e-01,
         -6.1273e-01, -8.4722e-01, -8.0991e-01, -7.6246e-01, -4.2533e-01,
          1.5898e-01,  6.9969e-01,  5.8429e-01, -4.9997e-01, -5.2366e-01,
          8.8733e-01, -4.8413e-01, -5.6845e-02,  3.2823e-01, -2.7535e-01,
          5.8646e-02,  3.8791e-01, -8.5948e-01,  6.2797e-01,  5.4898e-01,
         -7.1709e-01, -7.8377e-01, -7.7660e-02,  4.1973e-01,  6.7400e-01,
          5.9089e-01,  3.3314e-01, -2.3317e-01, -7.5179e-01, -7.0252e-01,
          4.7355e-01,  6.5309e-01, -8.6008e-01,  5.7294e-01, -1.7956e-01,
          1.3361e-01, -2.9017e-01,  6.4041e-01, -7.7170e-01,  8.7169e-02,
          8.8803e-01,  7.7755e-02, -6.8601e-01, -9.1543e-01, -8.5256e-01,
         -9.0660e-01, -8.4668e-01, -4.3815e-01,  5.3448e-01,  8.5567e-01,
         -5.5684e-01, -3.5364e-01,  7.4125e-01, -8.5456e-01,  8.7243e-01,
         -5.7543e-01, -1.1064e-01,  8.2345e-01, -4.2815e-01, -9.5575e-01,
         -5.3964e-01, -1.2472e-01, -2.6998e-02, -3.1720e-01, -8.1200e-01,
         -7.5506e-01,  3.3415e-01,  7.3341e-01, -8.5793e-01,  6.4847e-01,
         -5.6165e-01,  7.2199e-01, -9.2167e-01,  2.5240e-01,  8.1152e-01,
         -1.7183e-02, -7.0769e-01, -7.6791e-01, -8.2818e-01, -5.9550e-01,
          2.5480e-01,  8.9480e-01, -2.4253e-01, -5.7866e-01,  7.8063e-01,
         -7.6692e-01,  6.2261e-01,  9.4085e-02, -9.0210e-01, -1.4296e-04,
          6.2649e-01,  9.1617e-01,  8.8035e-01,  6.7028e-01, -4.5758e-02,
         -8.4817e-01,  1.6099e-01,  4.0293e-01, -7.1112e-01,  4.4229e-01,
          9.8246e-02, -8.2635e-01,  1.1687e-01,  7.1600e-01,  8.9065e-01,
          8.5364e-01,  6.9439e-01, -1.7529e-01, -8.1223e-01,  6.4583e-01,
         -2.6679e-01,  3.2047e-01, -7.2581e-01,  8.0636e-01,  4.9336e-01,
         -2.0059e-01, -6.0420e-01, -4.9924e-01,  1.3116e-01,  9.0643e-01,
         -4.6592e-02, -5.2387e-01,  7.1004e-01, -4.2669e-01, -5.1694e-01,
          8.1108e-01,  7.5543e-01,  5.2991e-01,  6.4293e-01,  8.5856e-01,
          1.3298e-01, -8.7074e-01,  8.7304e-01, -8.7711e-01,  8.4589e-01,
          2.1442e-01, -7.4926e-01, -9.1862e-01, -8.8200e-01, -6.1612e-01,
          4.2480e-01,  4.5994e-01, -8.4408e-01,  7.6445e-01, -1.5454e-01,
         -8.0695e-01, -3.2293e-01,  4.7710e-02, -2.8813e-01, -8.6099e-01,
         -2.2353e-02,  5.6905e-01, -6.0185e-01,  1.0927e-01,  7.9795e-01,
          1.8758e-01, -7.9930e-02,  3.4110e-01,  9.4077e-01, -1.3996e-01,
         -3.0946e-01,  9.9684e-02,  6.7823e-01, -6.0815e-01, -8.7170e-01,
         -7.4614e-01, -3.8429e-01,  7.7112e-01, -2.5681e-01,  3.2833e-01,
         -9.2135e-01,  3.1816e-01,  8.3094e-01,  7.2145e-01,  2.3642e-01,
         -9.5059e-01,  7.0914e-01, -7.4771e-01,  7.0298e-01,  7.1790e-01,
          3.7567e-01,  6.8837e-01,  8.7450e-01, -8.0358e-01,  6.4598e-01,
         -9.2151e-01, -5.6793e-02,  5.2370e-01,  3.3493e-01, -5.9378e-01,
         -2.9835e-01,  3.8455e-01,  1.1808e-01, -8.4276e-01, -7.7548e-01,
         -8.7156e-01, -2.5093e-01,  8.6722e-01, -8.1676e-01,  1.6701e-01,
          8.8189e-01,  8.6866e-01,  7.2987e-01, -6.6552e-01,  5.0933e-01,
         -8.2186e-01,  2.9475e-01,  6.9221e-01,  4.3710e-01, -8.1872e-01,
          3.5778e-01, -7.1707e-01,  5.5313e-01,  8.7464e-01,  6.8380e-01,
         -5.5376e-01,  1.6429e-01, -5.9098e-01,  6.8535e-01,  8.9860e-01,
          4.5806e-01, -8.7448e-01,  7.2048e-01, -7.5579e-01, -6.5157e-01,
         -6.1910e-01, -7.8635e-01,  8.1934e-01, -9.2723e-01, -5.7567e-02,
          2.9977e-01, -3.6001e-01, -6.1193e-01,  5.3612e-01,  4.9506e-01,
         -5.5679e-02,  4.4219e-01,  6.9989e-01, -7.8661e-01, -1.3486e-01,
          4.1566e-01,  1.2228e-03, -9.1840e-01,  8.8469e-01, -7.8940e-02,
         -4.6307e-01,  1.8830e-01,  6.4715e-01, -4.4903e-01, -7.8729e-01,
         -6.6936e-01, -8.6606e-01,  9.2661e-01, -6.8972e-01, -9.2163e-01,
         -8.7048e-01,  4.2997e-01, -4.8522e-01,  7.3133e-01,  8.2782e-01,
         -1.1818e-01, -1.1519e-01, -6.7208e-01, -3.6443e-01, -8.3441e-01,
          4.1011e-01, -9.5316e-01, -3.1651e-01, -8.1934e-01,  4.0803e-02,
         -5.6391e-01,  2.4609e-01]])
> predicted_y:  tensor([[-0.2572, -0.3931, -0.6574, -0.7525, -0.9772, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]],
       grad_fn=<ViewBackward0>)
> loss:  tensor(1.2340, grad_fn=<MseLossBackward0>)
> sos:  tensor([[nan, nan, nan, nan, nan, nan]], grad_fn=<AddmmBackward0>)
> y:  tensor([[ 3.8791e-01, -8.5948e-01,  6.2797e-01,  5.4898e-01, -7.1709e-01,
         -7.8377e-01, -7.7660e-02,  4.1973e-01,  6.7400e-01,  5.9089e-01,
          3.3314e-01, -2.3317e-01, -7.5179e-01, -7.0252e-01,  4.7355e-01,
          6.5309e-01, -8.6008e-01,  5.7294e-01, -1.7956e-01,  1.3361e-01,
         -2.9017e-01,  6.4041e-01, -7.7170e-01,  8.7169e-02,  8.8803e-01,
          7.7755e-02, -6.8601e-01, -9.1543e-01, -8.5256e-01, -9.0660e-01,
         -8.4668e-01, -4.3815e-01,  5.3448e-01,  8.5567e-01, -5.5684e-01,
         -3.5364e-01,  7.4125e-01, -8.5456e-01,  8.7243e-01, -5.7543e-01,
         -1.1064e-01,  8.2345e-01, -4.2815e-01, -9.5575e-01, -5.3964e-01,
         -1.2472e-01, -2.6998e-02, -3.1720e-01, -8.1200e-01, -7.5506e-01,
          3.3415e-01,  7.3341e-01, -8.5793e-01,  6.4847e-01, -5.6165e-01,
          7.2199e-01, -9.2167e-01,  2.5240e-01,  8.1152e-01, -1.7183e-02,
         -7.0769e-01, -7.6791e-01, -8.2818e-01, -5.9550e-01,  2.5480e-01,
          8.9480e-01, -2.4253e-01, -5.7866e-01,  7.8063e-01, -7.6692e-01,
          6.2261e-01,  9.4085e-02, -9.0210e-01, -1.4296e-04,  6.2649e-01,
          9.1617e-01,  8.8035e-01,  6.7028e-01, -4.5758e-02, -8.4817e-01,
          1.6099e-01,  4.0293e-01, -7.1112e-01,  4.4229e-01,  9.8246e-02,
         -8.2635e-01,  1.1687e-01,  7.1600e-01,  8.9065e-01,  8.5364e-01,
          6.9439e-01, -1.7529e-01, -8.1223e-01,  6.4583e-01, -2.6679e-01,
          3.2047e-01, -7.2581e-01,  8.0636e-01,  4.9336e-01, -2.0059e-01,
         -6.0420e-01, -4.9924e-01,  1.3116e-01,  9.0643e-01, -4.6592e-02,
         -5.2387e-01,  7.1004e-01, -4.2669e-01, -5.1694e-01,  8.1108e-01,
          7.5543e-01,  5.2991e-01,  6.4293e-01,  8.5856e-01,  1.3298e-01,
         -8.7074e-01,  8.7304e-01, -8.7711e-01,  8.4589e-01,  2.1442e-01,
         -7.4926e-01, -9.1862e-01, -8.8200e-01, -6.1612e-01,  4.2480e-01,
          4.5994e-01, -8.4408e-01,  7.6445e-01, -1.5454e-01, -8.0695e-01,
         -3.2293e-01,  4.7710e-02, -2.8813e-01, -8.6099e-01, -2.2353e-02,
          5.6905e-01, -6.0185e-01,  1.0927e-01,  7.9795e-01,  1.8758e-01,
         -7.9930e-02,  3.4110e-01,  9.4077e-01, -1.3996e-01, -3.0946e-01,
          9.9684e-02,  6.7823e-01, -6.0815e-01, -8.7170e-01, -7.4614e-01,
         -3.8429e-01,  7.7112e-01, -2.5681e-01,  3.2833e-01, -9.2135e-01,
          3.1816e-01,  8.3094e-01,  7.2145e-01,  2.3642e-01, -9.5059e-01,
          7.0914e-01, -7.4771e-01,  7.0298e-01,  7.1790e-01,  3.7567e-01,
          6.8837e-01,  8.7450e-01, -8.0358e-01,  6.4598e-01, -9.2151e-01,
         -5.6793e-02,  5.2370e-01,  3.3493e-01, -5.9378e-01, -2.9835e-01,
          3.8455e-01,  1.1808e-01, -8.4276e-01, -7.7548e-01, -8.7156e-01,
         -2.5093e-01,  8.6722e-01, -8.1676e-01,  1.6701e-01,  8.8189e-01,
          8.6866e-01,  7.2987e-01, -6.6552e-01,  5.0933e-01, -8.2186e-01,
          2.9475e-01,  6.9221e-01,  4.3710e-01, -8.1872e-01,  3.5778e-01,
         -7.1707e-01,  5.5313e-01,  8.7464e-01,  6.8380e-01, -5.5376e-01,
          1.6429e-01, -5.9098e-01,  6.8535e-01,  8.9860e-01,  4.5806e-01,
         -8.7448e-01,  7.2048e-01, -7.5579e-01, -6.5157e-01, -6.1910e-01,
         -7.8635e-01,  8.1934e-01, -9.2723e-01, -5.7567e-02,  2.9977e-01,
         -3.6001e-01, -6.1193e-01,  5.3612e-01,  4.9506e-01, -5.5679e-02,
          4.4219e-01,  6.9989e-01, -7.8661e-01, -1.3486e-01,  4.1566e-01,
          1.2228e-03, -9.1840e-01,  8.8469e-01, -7.8940e-02, -4.6307e-01,
          1.8830e-01,  6.4715e-01, -4.4903e-01, -7.8729e-01, -6.6936e-01,
         -8.6606e-01,  9.2661e-01, -6.8972e-01, -9.2163e-01, -8.7048e-01,
          4.2997e-01, -4.8522e-01,  7.3133e-01,  8.2782e-01, -1.1818e-01,
         -1.1519e-01, -6.7208e-01, -3.6443e-01, -8.3441e-01,  4.1011e-01,
         -9.5316e-01, -3.1651e-01, -8.1934e-01,  4.0803e-02, -5.6391e-01,
          2.4609e-01, -2.3960e-01, -5.5470e-01,  3.6925e-02,  7.1057e-01,
          1.9385e-01, -7.5652e-01,  3.0372e-01,  7.2236e-01, -5.1434e-02,
         -4.5202e-01, -4.6889e-01, -2.9327e-01, -8.4527e-01,  9.5000e-01,
          2.9422e-01,  4.6551e-01,  4.8772e-01,  1.5231e-01, -3.8959e-01,
          5.1982e-01, -2.3434e-01,  7.8953e-01,  7.5034e-01, -7.0884e-01,
          9.3019e-01,  8.5835e-01,  5.0013e-01, -4.5650e-01, -9.2527e-01,
         -6.3914e-01,  8.4755e-01,  5.3958e-02,  2.9492e-01,  4.8161e-01,
          4.0986e-01,  2.6331e-01,  6.6485e-01, -1.2293e-01, -4.1544e-01,
          7.0540e-01, -8.5294e-01, -8.7304e-01, -9.3342e-02, -3.5215e-01,
          8.9565e-02, -9.9046e-01,  6.1068e-01,  8.8629e-01, -5.3449e-01,
          8.3798e-01,  8.6538e-01, -5.1332e-01,  8.4274e-01,  7.6758e-01,
         -8.4532e-01, -1.5498e-01, -7.9410e-01,  7.7259e-01,  3.7390e-01,
          8.4366e-01, -3.6603e-01, -3.4202e-01,  8.1280e-01,  3.6973e-01,
          8.3627e-01, -8.3431e-01, -8.0810e-01, -2.2454e-02, -8.7870e-01,
         -9.0915e-01,  7.4053e-01,  8.5610e-01, -3.0779e-01,  7.9113e-01,
          4.4560e-01, -2.1048e-01, -2.7016e-01,  7.2329e-01,  6.3043e-01,
          6.9954e-01,  2.8789e-01,  6.5021e-01, -8.6112e-01, -7.1632e-01,
          3.5859e-01, -4.6140e-01,  8.2440e-01, -1.9628e-01,  7.0427e-01,
         -8.5660e-01, -4.4009e-01, -6.3181e-02, -2.3422e-01,  2.6736e-01,
          1.6188e-01,  9.7502e-02,  4.7115e-01, -7.3173e-01, -8.8227e-01,
          5.1571e-01, -1.4965e-01,  7.9881e-01,  8.4567e-01, -3.0035e-01,
          4.5106e-01, -8.4508e-01, -8.3372e-01,  3.9501e-01,  5.2049e-02,
          2.1586e-01,  5.2867e-01, -5.7484e-01, -4.4536e-01,  8.3223e-02,
         -4.7484e-01,  8.4434e-01,  5.1355e-01,  6.7212e-01,  3.5271e-01,
          8.0240e-01, -8.0637e-01, -7.5211e-01,  8.1599e-01,  7.9293e-01,
          8.3560e-02,  7.8035e-01,  6.8556e-01, -4.3063e-01, -7.0244e-01,
          7.5738e-01, -8.4581e-02,  7.2929e-01, -9.1347e-01, -7.1973e-01,
         -1.6491e-01, -8.2702e-01, -9.3891e-01,  9.2490e-01, -1.2566e-02,
          7.4873e-01, -8.1891e-01, -4.7751e-01, -8.2838e-01,  5.7529e-01,
          8.2638e-01, -4.0364e-01,  6.7227e-01,  3.3408e-01,  9.5295e-01,
         -7.5554e-01, -8.3567e-01, -6.2258e-01,  7.2503e-01,  5.9781e-01,
          8.2340e-01, -2.6083e-01,  7.8459e-01,  5.1306e-01,  8.2925e-01,
         -8.7183e-01,  1.0709e-01,  2.2379e-01, -9.0614e-01,  8.1067e-01,
         -4.2589e-01, -5.7031e-01,  5.4186e-01, -2.7102e-01,  8.8015e-01,
          5.6653e-01,  7.8097e-01,  6.2334e-02, -7.7299e-02, -8.0579e-01,
         -4.1129e-01, -7.3453e-01, -4.5815e-01,  6.7041e-01, -2.6993e-01,
         -8.5183e-01, -6.7688e-01, -8.7236e-01, -1.7042e-01,  8.2323e-01,
         -8.0824e-01,  1.6807e-01,  8.2084e-01,  4.3708e-01,  3.5883e-01,
          7.7326e-01,  6.6365e-01, -7.3685e-01,  1.0451e-01,  2.6754e-01,
         -8.6266e-02, -4.3746e-01,  8.5332e-01, -2.6818e-01, -8.4285e-01,
          1.6539e-02,  6.2070e-01,  8.7858e-01,  6.1471e-01,  1.9753e-01,
         -2.4663e-01, -6.2930e-01, -8.8665e-01, -7.9760e-01, -5.1274e-01,
          1.9819e-01,  8.2065e-01,  5.8468e-01, -6.9930e-01, -5.0495e-01,
          1.0000e+00, -8.4754e-01,  7.6522e-01, -9.7491e-01,  5.4828e-01,
          7.5403e-01,  9.5173e-02, -1.0257e-01,  4.2208e-01,  8.9789e-01,
         -6.4619e-01,  6.2701e-01, -8.4820e-01, -2.3291e-01,  4.6961e-02,
         -5.0937e-01, -4.6640e-01,  4.6904e-01,  5.6395e-01,  1.4434e-03,
          3.5741e-01,  4.8187e-01, -2.0030e-01, -9.2236e-01, -8.6389e-01,
          4.4110e-01, -6.7375e-01,  3.2400e-01, -2.8044e-01, -3.0323e-01,
         -6.8602e-01, -5.8022e-01, -4.0448e-01, -3.2494e-01, -9.8747e-02,
         -8.1863e-01,  2.0085e-01,  1.8025e-01, -6.5746e-01, -5.6432e-01,
         -8.9734e-01,  8.5415e-01]])
> predicted_y:  tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<ViewBackward0>)
> loss:  tensor(nan, grad_fn=<MseLossBackward0>)
----------------------------------------------------------------------------------------------------
|=========> epoch_loss = nan | Loss = nan
****************************************************************************************************

Optimization loop took 0.05 seconds.

On the other hand, when I use a smaller sequence length, the error either does not happen or takes longer to appear. However, in all cases my gradient is None, which I think is why the code does not converge / work as it is supposed to.
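
One way to probe this (a rough sketch, not taken from my actual script: the coefficients below are just the stable target filter as a placeholder, and the interesting case would be the predicted, possibly unstable, sos from the failing iteration) is to filter the whole signal and a 512-sample slice in both float32 and float64 and compare:

import torch
from torchaudio.functional import lfilter

# stand-in test signal and placeholder coefficients (b0, b1, b2) / (a0, a1, a2)
torch.manual_seed(0)
x = torch.rand(1, 1024) * 2 - 1
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

for dtype in (torch.float32, torch.float64):
    xd, ad, bd = x.to(dtype), a.to(dtype), b.to(dtype)
    full = lfilter(xd, ad, bd)            # whole array
    sub  = lfilter(xd[:, :512], ad, bd)   # sub-array only
    print(dtype,
          "| full has nan:", bool(torch.isnan(full).any()),
          "| sub has nan:",  bool(torch.isnan(sub).any()))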

Versions

Collecting environment information...
PyTorch version: 2.3.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: (Rev2, Built by MSYS2 project) 9.3.0
Clang version: Could not collect
CMake version: version 3.17.1
Libc version: N/A

Python version: 3.9.13 (main, Aug 25 2022, 23:51:50) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2803
DeviceID=CPU0
Family=198
L2CacheSize=5120
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2803
Name=11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] flake8==4.0.1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.25.2
[pip3] numpydoc==1.4.0
[pip3] onnx==1.11.0
[pip3] onnxruntime==1.15.0
[pip3] optree==0.11.0
[pip3] pytorch-lightning==2.1.3
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchmetrics==1.3.0
[pip3] torchvision==0.16.2
[conda] blas                      1.0                         mkl  
[conda] mkl                       2021.4.0                 pypi_0    pypi
[conda] mkl-service               2.4.0            py39h2bbff1b_0  
[conda] mkl_fft                   1.3.1            py39h277e83a_0  
[conda] mkl_random                1.2.2            py39hf11a4ad_0  
[conda] numpy                     1.25.2                   pypi_0    pypi
[conda] numpydoc                  1.4.0            py39haa95532_0  
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] pytorch-lightning         2.1.3                    pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] torchaudio                2.1.2                    pypi_0    pypi
[conda] torchmetrics              1.3.0                    pypi_0    pypi
[conda] torchvision               0.16.2                   pypi_0    pypi
yoyololicon commented 3 days ago

@SuperKogito A few suggestions:

  1. Try lowering the learning rate. 0.1 is a bit too high for Adam to work well.
  2. Apply some parameterisation to the a coefficients to ensure stability. Some reference: https://arxiv.org/abs/2103.08709 (they also use biquads, just like your sos parameters).
SuperKogito commented 7 hours ago

@yoyololicon thank you for your response. With a lower learning rate, a stability constraint, and a bit more fine-tuning, I managed to get a stable run and a more smoothly decaying loss.

def forward(self, x):
    for layer in self.layers:
        x = layer(x)

    sos = x
    # make sure that a0 = ones 
    a0 = torch.ones_like(sos[:, :, 0])

    # Enforce stability constraint based on https://arxiv.org/pdf/2103.08709 eq. (4, 5)
    a1 = 2 * torch.tanh(sos[:, :, 3])
    a2 = ((2 - torch.abs(a1)) * torch.tanh(sos[:, :, 4]) + torch.abs(a1)) * 0.5
    sos = torch.stack([sos[:, :, 0], sos[:, :, 1], sos[:, :, 2], a0, a1, a2], dim=-1)

    return sos

For biquads this seems to work (somewhat). I have one more inquiry regarding the lfilter() function: in order to use it to compute the output of a cascade of biquads, am I supposed to call it in a loop as follows:

class BiquadsCascade(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, a, b):
        # apply each biquad section in sequence
        num_biquads = a.shape[0]
        for biquad_idx in range(num_biquads):
            x = lfilter(x, a[biquad_idx, :], b[biquad_idx, :])
        return x

since lfilter expects a single filter and not a cascade of biquads. Is there a better way to do this? Switching between cascaded coefficients and high-order coefficients using scipy does not seem to be an option for me, since that implementation is probably not differentiable and would therefore hinder back-propagation.

yoyololicon commented 7 hours ago

You can reference this implementation to combine multiple biquads into one high-order filter. https://github.com/yoyololicon/golf/blob/52f50e7341f769d49e6bddbbe887c149c2b9a413/models/utils.py#L444-L460

Note: The above code only handles the poles (denominator of the transfer function), but the function can also be applied to zeros.
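
For reference, here is a rough self-contained sketch of the same idea (this is not the exact code from the repo above, just the polynomial-multiplication view of a cascade; it assumes your earlier sos layout of shape [num_biquads, 6] with coefficients ordered b0, b1, b2, a0, a1, a2):

import torch
import torch.nn.functional as F
from torchaudio.functional import lfilter

def poly_mul(p, q):
    # full 1-D convolution of two coefficient vectors == polynomial product,
    # which keeps everything differentiable for back-propagation
    return F.conv1d(p.view(1, 1, -1), q.flip(-1).view(1, 1, -1),
                    padding=q.numel() - 1).view(-1)

def sos_to_ba(sos):
    # multiply the numerator / denominator polynomials of all biquad sections
    # to get a single high-order (b, a) pair
    b, a = sos[0, :3], sos[0, 3:]
    for k in range(1, sos.shape[0]):
        b = poly_mul(b, sos[k, :3])
        a = poly_mul(a, sos[k, 3:])
    return b, a

# usage: one lfilter call instead of a Python loop over the sections
sos = torch.rand(3, 6)                    # hypothetical cascade of 3 biquads
b, a = sos_to_ba(sos)
y = lfilter(torch.rand(1, 1024), a, b)

Keep in mind that a single high-order direct-form filter can be less numerically robust than evaluating the biquads as a cascade, so it is worth checking both variants on your data.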

SuperKogito commented 5 hours ago

This looks very helpful. Thank you :)