KinWaiCheuk / nnAudio

Audio processing using PyTorch 1D convolution networks
MIT License

Spectrograms not updating well at low frequency bins #115

Open arachid1 opened 2 years ago

arachid1 commented 2 years ago

Hello, thanks for putting together a really useful library!

I'm working on a pneumonia detection problem. My dataset is heavily imbalanced, with 2000+ non-pneumonia cases and only 142 pneumonia cases, so I decided to stick with 142 cases of each label to keep the dataset balanced.

I am trying to apply the STFT layer in the following model:

[screenshot of the model architecture]

with the following parameters:

self.spec_layer = Spectrogram.STFT(n_fft=256, hop_length=128, sr=8000, trainable=True, output_format="Magnitude")
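For reference, a quick sketch of the output this layer produces for a 5-second clip at 8000 Hz (the exact number of time frames depends on the padding settings):

import torch
from nnAudio import Spectrogram

spec_layer = Spectrogram.STFT(n_fft=256, hop_length=128, sr=8000,
                              trainable=True, output_format="Magnitude")

x = torch.randn(1, 8000 * 5)   # dummy batch: one 5-second waveform at 8 kHz
spec = spec_layer(x)
print(spec.shape)              # roughly (1, 129, 313): n_fft//2+1 = 129 bins, each sr/n_fft = 31.25 Hz wide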

Now, I'm observing some modifications of the spectrograms as the model trains, but the trained spectrogram seems to get updated mainly at the higher frequency bins. It should be the low-frequency bins that inform the network's decision-making, since lung sounds lie in the 0-4000 Hz range and I sample at 8000 Hz. Here is a spectrogram of a pneumonia sample before training:

outputs__orig_index_9_label_1

and here is its updated version at epochs 10, 50, and 140, respectively:

outputs___9_label_1_epoch_10

outputs___9_label_1_epoch_50

outputs___9_label_1_epoch_140

Since the changes are really hard to see, I generate a difference map (= trained spectrogram at a given epoch minus the original untrained spectrogram). Here are the difference maps at epochs 10, 50, and 140, respectively:

diff___9_label_1_epoch_10

diff___9_label_1_epoch_50

diff___9_label_1_epoch_140

It's difficult to see, but there are some slight modifications in the lower frequency bins 0-24; they are small, though, and there is barely any change for bins 0-12.
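A minimal sketch of how such a difference map can be computed and plotted, assuming spec_before and spec_after hold the magnitude spectrograms of the same clip before and after training (the names are illustrative):

import matplotlib.pyplot as plt

def plot_diff_map(spec_before, spec_after, path="diff_map.png"):
    # spec_before / spec_after: (freq_bins, time_frames) tensors for the same clip
    diff = (spec_after - spec_before).detach().cpu().numpy()
    plt.imshow(diff, origin="lower", aspect="auto", cmap="coolwarm")
    plt.colorbar(label="trained - untrained magnitude")
    plt.xlabel("time frame")
    plt.ylabel("frequency bin")
    plt.savefig(path)
    plt.close()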

Some of the training parameters are:

parameters.lr = 1e-4
parameters.n_epochs = 150
parameters.batch_size = 32
parameters.audio_length = 5

I use nnAudio == 0.2.6.

KinWaiCheuk commented 2 years ago

Hi @arachid1. Thanks for using nnAudio! The STFT kernel updates are gradient driven (and the gradients are calculated from the loss function that you use). Just like with other neural network layers in PyTorch, it is sometimes quite difficult to understand why the model decided to update some parameters but not others.

However, since you already have the prior knowledge that most of the important information is in the low-frequency region, is it possible for you to set trainable=True and then freeze all the high-frequency bins? If you are using PyTorch Lightning, one simple example would look like this.

import torch
import torch.nn as nn
from nnAudio import Spectrogram
import pytorch_lightning as pl
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
import torch.optim as optim
import matplotlib.pyplot as plt

X, Y = make_blobs(1000,44100,centers=10, cluster_std=10)
X_train, X_test, y_train, y_test = train_test_split(X,Y, test_size=0.2, random_state=0)

trainset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),torch.from_numpy(y_train))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,shuffle=True, num_workers=2)

class Model(pl.LightningModule):
    def __init__(self):
        super(Model, self).__init__()
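        # default STFT settings: n_fft=2048 -> 1025 frequency bins; 44100-sample inputs with hop_length=512 -> 87 time frames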
        self.stft_layer = Spectrogram.STFT(trainable=True, output_format='Magnitude')
        self.classifier = nn.Linear(1025*87,10)

    def on_after_backward(self):
        # wsin/wcos are the trainable STFT kernels (one row per frequency bin)
        # freeze bins 20-1024, i.e. only bins 0-19 keep their gradients
        self.stft_layer.wsin.grad[20:] = 0
        self.stft_layer.wcos.grad[20:] = 0

    def forward(self, x):
        x = self.stft_layer(x)
        x = self.classifier(x.flatten(1))
        return x

    def training_step(self, batch, batch_idx):
        pred = self(batch[0])
        loss = torch.nn.functional.cross_entropy(pred, batch[1])
        return loss

    def configure_optimizers(self):
        r"""Configure optimizer."""
        return optim.Adam(self.parameters())

model = Model()

# clone() is needed: detach() alone shares storage with the parameter,
# so the snapshot would otherwise follow the in-place optimizer updates
original_weight = model.stft_layer.wsin.detach().clone()

trainer = pl.Trainer(max_epochs=2, gpus=1)

trainer.fit(model, trainloader)

changed_weight = model.stft_layer.wsin.detach()

# check whether bins 0-19 (the trainable ones) changed during training
print(torch.equal(original_weight[:20], changed_weight[:20]))
# It should return False

# check whether bins 20-1024 (the frozen ones) stayed the same
print(torch.equal(original_weight[20:], changed_weight[20:]))
# It should return True

If you are using plain PyTorch, you can set the gradients for higher bins to 0 after loss.backward() and before optimizer.step().
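A minimal sketch of that plain-PyTorch variant, reusing model and trainloader from the example above (a LightningModule is still an ordinary nn.Module) and keeping only the first 20 bins trainable for illustration:

optimizer = optim.Adam(model.parameters())

for x, y in trainloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()

    # zero the gradients of the high-frequency rows of the STFT kernels,
    # so optimizer.step() only updates bins 0-19
    model.stft_layer.wsin.grad[20:] = 0
    model.stft_layer.wcos.grad[20:] = 0

    optimizer.step()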

This is just my suggestion; I am not sure whether it works or not. But if it does, I think it would be worth an update to nnAudio to let users control which STFT bins are trained. Please feel free to update me with your latest results!

arachid1 commented 2 years ago

Thanks for your suggestion!

Unfortunately, it doesn't seem to make a great deal of difference, as illustrated with the same spectrogram before and after training below:

[spectrograms of sample index 7 (label 0) before and after training]

but the freezing does work, could be a very useful tool, and helped me investigate a lot. Here are the kernels at epochs 10, 25, 50, and 100:

kernels_10 kernels_25 kernels_50 kernels_100

and here are the respective weights:

weights_10_wsin_last_epoch weights_25_wsin_last_epoch weights_50_wsin_last_epoch weights_100_wsin_last_epoch

My guess is that the kernels extract a lot more from the lower frequencies when running the STFT, because my data is richer in those frequencies; if those activations get backpropagated more, it makes sense that the corresponding kernels change more too. If you have more thoughts on this, please let me know.

It still doesn't explain why the higher frequencies change so much, which could be due to many things, but this is a great step towards explainable ML.

Another way to frame the odd finding:

KinWaiCheuk commented 2 years ago

Thanks for your update! As for why the higher bins receive larger updates than the lower bins despite the richer information in the lower bins, I also have no idea. But it reminds me of a paper our team saw a few weeks ago about restricted vs. unrestricted trainable front-ends: https://arxiv.org/pdf/2109.02774.pdf

Although they focused on Mel spectrograms rather than the STFT, I think their idea still applies to our case.

They found that unrestricted kernel training is harmful to model performance. nnAudio is unrestricted, since we do not impose any constraints on how the kernel parameters are updated, whereas in their paper they imposed shape constraints (triangular, rectangular, or Gaussian) on the kernels.

So to control the (40, 241), i.e. (num_mels, n_fft//2+1), Mel filter basis, they use only 80 parameters (40 bandwidths and 40 center frequencies) to control the triangular filter shapes and locations, as opposed to nnAudio, which uses 40*241 parameters to control everything freely. To me, if we provide too many degrees of freedom, backprop might not be able to do a great job.
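A rough sketch of what such a constrained parameterization could look like (the class name and shapes are illustrative, not the paper's or nnAudio's implementation): each triangular filter is rebuilt on every forward pass from just a learnable center and bandwidth, so backprop only ever sees 80 parameters.

import torch
import torch.nn as nn

class TriangularFilterbank(nn.Module):
    # 40 triangular filters over 241 frequency bins, driven by only 80 scalars
    # (a center and a bandwidth per filter) instead of 40*241 free weights
    def __init__(self, n_filters=40, n_bins=241):
        super().__init__()
        self.register_buffer("bins", torch.arange(n_bins).float())
        self.centers = nn.Parameter(torch.linspace(5, n_bins - 5, n_filters))
        self.bandwidths = nn.Parameter(torch.full((n_filters,), 10.0))

    def forward(self, spec):
        # spec: (batch, n_bins, time) magnitude spectrogram
        dist = (self.bins[None, :] - self.centers[:, None]).abs()            # (n_filters, n_bins)
        width = self.bandwidths[:, None].abs().clamp(min=1e-3)
        filters = torch.clamp(1 - dist / width, min=0)                       # triangles, peak 1 at each center
        return torch.matmul(filters, spec)                                   # (batch, n_filters, time)

fb = TriangularFilterbank()
print(fb(torch.rand(8, 241, 100)).shape)   # torch.Size([8, 40, 100])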

I think we might also want to apply some sort of constraint to the STFT kernels to get better results. I haven't tried this idea yet, but I agree that understanding this is a great step towards explainable ML! I am looking forward to seeing your new findings!