harritaylor / torchvggish

PyTorch port of Google Research's VGGish model, used for extracting audio features.
Apache License 2.0
377 stars, 68 forks

How to feed the model batch inputs? #21

Open · ChenHuaYou opened this issue 3 years ago

ChenHuaYou commented 3 years ago

Hi, how can I feed batched inputs to the model? I only know how to feed a single audio file. What should I do if I want to feed a batch? Thanks.
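One way to batch several files without modifying the network, sketched here as an assumption rather than a documented torchvggish feature: convert each file with `vggish_input.wavfile_to_examples()`, concatenate all examples along the first dimension, run a model built with `preprocess=False` once, then split the embeddings back per file using the number of examples each file produced. In PyTorch the two steps would be `torch.cat(..., dim=0)` and `torch.split(embeddings, counts)`. The pure-Python sketch below models only that bookkeeping; `pack_examples`, `unpack_embeddings` and `fake_model` are hypothetical names.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def pack_examples(per_file: Sequence[Sequence[T]]) -> Tuple[List[T], List[int]]:
    """Flatten per-file example lists into one batch and remember each file's count."""
    flat = [ex for examples in per_file for ex in examples]
    counts = [len(examples) for examples in per_file]
    return flat, counts


def unpack_embeddings(flat: Sequence[U], counts: Sequence[int]) -> List[List[U]]:
    """Split a flat list of per-example outputs back into one list per file."""
    if sum(counts) != len(flat):
        raise ValueError("counts do not match the number of outputs")
    out, start = [], 0
    for n in counts:
        out.append(list(flat[start:start + n]))
        start += n
    return out


# Hypothetical stand-in for the model: maps each example to one "embedding".
def fake_model(batch: List[str]) -> List[str]:
    return ["emb(" + ex + ")" for ex in batch]


files = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]  # 3 files with 3, 1 and 2 examples
flat, counts = pack_examples(files)
per_file = unpack_embeddings(fake_model(flat), counts)
print(counts)       # [3, 1, 2]
print(per_file[1])  # ['emb(b0)']
```

Because every example is processed independently by VGGish, concatenating files this way does not change any individual embedding; it only amortizes the forward pass.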

singhal2 commented 3 years ago

+1

Chevalier1024 commented 3 years ago

Hi all, here is how I feed batched inputs to the model:

  1. Run vggish_input.wavfile_to_examples() on every file in your dataset and save the results.
  2. Fine-tune the model with a wrapper module like this:

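    import torch.nn as nn
    from torchvggish import vggish  # assumes the repo's torchvggish package is importable

    class FinetuneModel(nn.Module):
        def __init__(self, classifier: nn.Module):
            super().__init__()
            urls = {
                'vggish': "https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth"
            }
            # Pretrained VGGish without built-in preprocessing (inputs are precomputed
            # log-mel examples) and without PCA postprocessing (raw 128-d embeddings).
            self.pretrain = vggish.VGGish(urls, preprocess=False, postprocess=False)
            # Your own head; it receives embeddings shaped [bs, 128, num_frames].
            self.classifier = classifier

        def forward(self, x):
            """
            :param x: [bs, num_frames, 96, 64] precomputed log-mel examples
            :return: output of self.classifier
            """
            bs, num_frames, height, width = x.size()
            # Fold the frames into the batch dimension: VGGish expects [N, 1, 96, 64].
            x = x.reshape(bs * num_frames, 1, height, width)
            x = self.pretrain(x)  # [bs*num_frames, 128]
            # Unfold to [bs, num_frames, 128] first, then move the embedding axis
            # forward; viewing straight to [bs, 128, num_frames] would mix frames.
            x = x.view(bs, num_frames, -1).transpose(1, 2)  # [bs, 128, num_frames]
            return self.classifier(x)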
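The reshape after the VGGish call is where frame grouping can silently break. The embeddings come out as `[bs*num_frames, 128]` in row-major order (all frames of item 0, then all frames of item 1, and so on). Reshaping that directly to `[bs, 128, num_frames]` interleaves values from different frames, whereas reshaping to `[bs, num_frames, 128]` and then transposing keeps each frame's embedding intact. The sketch below simulates a row-major `view` in pure Python on tiny sizes (one item, 2 frames, embedding size 3 instead of 128); `row_major_view` and `transpose` are hypothetical helpers, not PyTorch functions.

```python
from typing import List


def row_major_view(flat: List[int], rows: int, cols: int) -> List[List[int]]:
    """Reinterpret a flat list as a rows x cols matrix, like Tensor.view on contiguous data."""
    assert rows * cols == len(flat), "shape does not match number of elements"
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def transpose(m: List[List[int]]) -> List[List[int]]:
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


num_frames, dim = 2, 3
# Flattened model output for one item: frame 0 = [10, 11, 12], frame 1 = [20, 21, 22].
flat = [10, 11, 12, 20, 21, 22]

# Wrong: view straight to [dim, num_frames]; rows mix values from different frames.
wrong = row_major_view(flat, dim, num_frames)
# Right: view to [num_frames, dim], then transpose to [dim, num_frames].
right = transpose(row_major_view(flat, num_frames, dim))

print(wrong)  # [[10, 11], [12, 20], [21, 22]]
print(right)  # [[10, 20], [11, 21], [12, 22]]
```

In `right`, row `c` holds embedding dimension `c` for every frame, which is what a channels-first head such as `nn.Conv1d(128, ...)` expects. In PyTorch this corresponds to `x.view(bs, num_frames, -1).transpose(1, 2)`.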

ifeelagood commented 1 year ago

@Chevalier1024 so each batch has a different number of samples?
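Within one batch, a `[bs, num_frames, 96, 64]` tensor forces every item to share the same `num_frames`, although `num_frames` may differ from one batch to the next. When files of different lengths share a batch, one common option (an assumption here; the thread does not say what Chevalier1024 did) is to pad every item to the longest one and carry a mask, so that pooling over frames ignores the padding. In PyTorch, `torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)` performs the padding. The pure-Python sketch below uses one number per frame as a stand-in for a 96x64 example or a 128-d embedding; `pad_batch` and `masked_mean` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    seqs: Sequence[Sequence[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad variable-length sequences to the longest one and return a validity mask."""
    if not seqs or any(len(s) == 0 for s in seqs):
        raise ValueError("every item must contain at least one frame")
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    mask = [[i < len(s) for i in range(max_len)] for s in seqs]
    return padded, mask


def masked_mean(padded: List[List[float]], mask: List[List[bool]]) -> List[float]:
    """Average over real frames only, ignoring padded positions."""
    out = []
    for row, valid in zip(padded, mask):
        vals = [v for v, ok in zip(row, valid) if ok]
        out.append(sum(vals) / len(vals))
    return out


seqs = [[1.0, 2.0, 3.0], [4.0], [5.0, 7.0]]  # 3 items with 3, 1 and 2 frames
padded, mask = pad_batch(seqs)
print(padded)                     # [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 7.0, 0.0]]
print(masked_mean(padded, mask))  # [2.0, 4.0, 6.0]
```

Cropping every file to a fixed number of examples is a simpler alternative when losing the tail of long recordings is acceptable.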