YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

single audio inference for ast_model #19

Closed JeffC0628 closed 2 years ago

JeffC0628 commented 2 years ago

Hi Yuan, I have written a pretty simple script to verify the tags of a single wav file, but the result doesn't seem right. Could you help me spot the mistake?

  import os
  import sys
  import csv

  import numpy as np
  import torch
  import torchaudio
  from src.models import ASTModel
  torchaudio.set_audio_backend("soundfile")       # switch backend
  basepath = os.path.dirname(os.path.dirname(sys.path[0]))
  sys.path.append(basepath)

  # download pretrained model in this directory
  os.environ['TORCH_HOME'] = '../pretrained_models'

  def make_features(wav_name, mel_bins, target_length=1024):
      waveform, sr = torchaudio.load(wav_name)

      fbank = torchaudio.compliance.kaldi.fbank(
          waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
          window_type='hanning', num_mel_bins=mel_bins, dither=0.0,
          frame_shift=10)

      n_frames = fbank.shape[0]
      p = target_length - n_frames
      # cut and pad
      if p > 0:
          m = torch.nn.ZeroPad2d((0, 0, 0, p))
          fbank = m(fbank)
      elif p < 0:
          fbank = fbank[0:target_length, :]

      return fbank

  def load_label(label_csv):
      # Load label
      with open(label_csv, 'r') as f:
          reader = csv.reader(f, delimiter=',')
          lines = list(reader)

      labels = []
      ids = []  # Each label has a unique id such as "/m/068hy"
      for i1 in range(1, len(lines)):
          id = lines[i1][1]
          label = lines[i1][2]
          ids.append(id)
          labels.append(label)
      return labels

  if __name__ == '__main__':

      label_csv = './ast/egs/audioset/data/class_labels_indices.csv'

      # 1. make feature for predict
      wav_name = './ast/egs/audioset/data/0OxlgIitVig.wav'
      feats = make_features(wav_name, mel_bins=128)           # shape(1024, 128)

      # the spectrogram has been padded/cropped to 1024 time frames
      input_tdim = feats.shape[0]

      # 2. load the best model and the weights
      checkpoint_path = './ast/pretrained_models/audioset_10_10_0.4593.pth'
      ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
      print(f'[*INFO] load checkpoint: {checkpoint_path}')
      checkpoint = torch.load(checkpoint_path, map_location='cuda')
      audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
      audio_model.load_state_dict(checkpoint)

      audio_model = audio_model.to(torch.device("cuda:0"))

      # 3. feed the data feature to model
      feats_data = feats.expand(1, input_tdim, 128)           # add a batch dimension: (1, input_tdim, 128)

      audio_model.eval()                                      # set the model to evaluation mode
      with torch.no_grad():
          output = audio_model.forward(feats_data)
          output = torch.sigmoid(output)
      result_output = output.data.cpu().numpy()[0]

      # 4. map the post-prob to label
      labels = load_label(label_csv)

      sorted_indexes = np.argsort(result_output)[::-1]

      # Print audio tagging top probabilities
      for k in range(10):
          print('{}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]],
                                    result_output[sorted_indexes[k]]))

      # output has shape [1, 527], i.e., 1 sample with predictions over the 527 AudioSet classes.
      # print(result_output.shape)

and the output:

  Speech: 0.1906
  Music: 0.0481
  Inside, small room: 0.0245
  Musical instrument: 0.0100
  Silence: 0.0088
  Sound effect: 0.0074
  Outside, rural or natural: 0.0064
  Animal: 0.0058
  Outside, urban or manmade: 0.0045
  Inside, large room or hall: 0.0041

YuanGongND commented 2 years ago

Hi Jeff,

Thanks. Which part of the output makes you think it is not correct?

There might be other things - but I noticed that you didn't normalize your fbank, which will likely lead to wrong results. If you use the AudioSet pretrained model, please use:

fbank = (fbank - (-4.2677393)) / (4.5689974 * 2) before returning fbank.
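
For illustration, here is a minimal sketch of your make_features with only that normalization added (same torch/torchaudio usage as in your script; the two constants are the AudioSet fbank mean and standard deviation):

  def make_features(wav_name, mel_bins, target_length=1024):
      waveform, sr = torchaudio.load(wav_name)

      fbank = torchaudio.compliance.kaldi.fbank(
          waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
          window_type='hanning', num_mel_bins=mel_bins, dither=0.0,
          frame_shift=10)

      # pad with zeros or crop to the fixed target length
      n_frames = fbank.shape[0]
      p = target_length - n_frames
      if p > 0:
          fbank = torch.nn.ZeroPad2d((0, 0, 0, p))(fbank)
      elif p < 0:
          fbank = fbank[0:target_length, :]

      # normalize with the AudioSet mean/std the pretrained model expects
      fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
      return fbank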

Also, it is highly encouraged to set audioset_pretrain=True when initializing the AST model rather than manually loading the state_dict. It should be fine in your case, but if your target length is not 1024, the first method will automatically adjust the positional embedding for you.
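
As a rough sketch of that initialization (assuming the same imports as your script, and noting that audioset_pretrain=True also expects imagenet_pretrain=True; the AudioSet checkpoint should be placed or downloaded under TORCH_HOME):

  # let the AST code load the AudioSet checkpoint and adapt the
  # positional embedding to input_tdim automatically
  input_tdim = 1024                                       # should match your target_length
  ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim,
                     imagenet_pretrain=True, audioset_pretrain=True)
  audio_model = ast_mdl.to(torch.device("cuda:0"))
  audio_model.eval()

  with torch.no_grad():
      # without DataParallel the input has to be moved to the GPU explicitly
      output = torch.sigmoid(audio_model(feats_data.to("cuda:0")))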

Please let me know if that helps.

-Yuan

JeffC0628 commented 2 years ago

That's the point I missed; it works better now, thanks.

tsw123tsw commented 10 months ago

Hi, why do you switch the torchaudio backend for inference with torchaudio.set_audio_backend("soundfile")?

YuanGongND commented 10 months ago

@tsw123tsw

I guess it might be related to some system package, but it is really not necessary; our official inference sample does not switch the backend: https://colab.research.google.com/github/YuanGongND/ast/blob/master/colab/AST_Inference_Demo.ipynb

-Yuan