harritaylor / torchvggish

Pytorch port of Google Research's VGGish model used for extracting audio features.
Apache License 2.0

How to use this code for training on my own dataset? #8

Closed vtddggg closed 4 years ago

vtddggg commented 4 years ago

Hello, I have some .wav files and I want to train a classification model on my own dataset.

How can I use this code? Extract embeddings and train a sequence model? Is it possible to fine-tune the VGGish feature extractor when training the classifier?

Any advice is appreciated. Thank you!!

harritaylor commented 4 years ago

Hi, you can do all of those things.

> Extract embeddings and train a sequence model?


```python
from torchvggish import vggish, vggish_input

# Initialise model and download weights
embedding_model = vggish()
embedding_model.eval()

example = vggish_input.wavfile_to_examples("example.wav")
embeddings = embedding_model.forward(example)
```

This will result in `embeddings`, a torch tensor of size `[n, 128]`, where `n` is roughly the number of seconds in "example.wav". You can use these to train a sequence model.

As long as you do `vggish_input.wavfile_to_examples()` as your preprocessing step, you're good to go.
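
For example, a minimal sequence classifier over those `[n, 128]` embeddings could look something like this (the LSTM head, hidden size, and class count are illustrative choices, not part of torchvggish):

```python
import torch.nn as nn

# Sketch: an LSTM over the per-second VGGish embeddings, ending in a classifier.
class EmbeddingLSTM(nn.Module):
    def __init__(self, num_classes=10, hidden_size=64):
        super(EmbeddingLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=128, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: [batch, n, 128] - batched sequences of VGGish embeddings
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])  # classify from the final hidden state

seq_model = EmbeddingLSTM()
logits = seq_model(embeddings.unsqueeze(0).float())  # add a batch dimension
```
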

> Is it possible to fine-tune the VGGish feature extractor when training the classifier?

Yes, you can use VGGish as part of a larger model by initialising it and *not* setting `eval()` mode.
Something like this (I believe this is right, but I'll have to double-check later; this is just to give the gist):

```python
import torch
import torch.nn as nn
from torchvggish import vggish

# Simple classification head: 128-d VGGish embedding -> 10 classes.
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear = nn.Linear(128, 10)

    def forward(self, x):
        out = self.linear(x)
        return out

# VGGish feature extractor plus classifier head, trained end to end.
class VGGishClassifier(nn.Module):
    def __init__(self):
        super(VGGishClassifier, self).__init__()
        self.embed = vggish()      # pretrained VGGish, left in train mode so it can be fine-tuned
        self.classifier = MLP()

    def forward(self, x):
        embedding = self.embed(x)  # [n, 128] embeddings
        out = self.classifier(embedding)
        return out

model = VGGishClassifier()
optimizer = torch.optim.Adam(model.parameters())  # updates both VGGish and the classifier
criterion = nn.CrossEntropyLoss()

# ... etc
```
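
For completeness, a rough sketch of what the training loop could look like (the dataloader, labels, and epoch count here are placeholders, not something torchvggish provides):

```python
# Hypothetical training loop: `train_loader` is assumed to yield
# (examples, labels), where `examples` comes from vggish_input and
# `labels` holds one class index per example.
for epoch in range(10):
    for examples, labels in train_loader:
        optimizer.zero_grad()
        logits = model(examples)          # [batch, 10] class scores
        loss = criterion(logits, labels)
        loss.backward()                   # gradients flow into VGGish as well
        optimizer.step()
```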

Hopefully this makes sense!