kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0
305 stars, 50 forks

Inference ESC-50 fine-tuned model #12

Closed: myatmyintzuthin closed this issue 2 years ago

myatmyintzuthin commented 2 years ago

Hello, authors. Thank you for sharing the great work.

I fine-tuned the AudioSet-pretrained model passt-s-f128-p16-s10-ap.476-swa.pt on the ESC-50 dataset using ex_esc50.py, and the checkpoint was saved to output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt. I want to load that checkpoint and run inference on an audio file. I tried loading the checkpoint and using passt_hear21 for inference, but I got lost in the process.

Could you please share how to run inference on an audio file with the saved checkpoint?

kkoutini commented 2 years ago

Hi, thanks! You can use passt_hear21 like this:

import torch
from hear21passt.base import load_model, get_scene_embeddings, get_timestamp_embeddings

# Loading the weights (map to CPU first; the model is moved to the GPU below)
p = "output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt"
ckpt = torch.load(p, map_location="cpu")
net_statedict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # SWA weights (can be loaded instead of the main weights)

# getting the model
model = load_model(mode="logits").cuda()
model.net.load_state_dict(net_statedict)  # loading the fine-tuned weights
model.eval()  # disable training-time behaviour such as spectrogram masking

# example: a batch of 3 five-second clips at 32 kHz, on the same device as the model
wave_example = torch.ones((3, 32000 * 5)).cuda() * 0.5
with torch.no_grad():
    logits = model(wave_example)
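The loading lines above work because each key in the checkpoint's `state_dict` carries the name of the submodule it belongs to in the training module: judging by the prefixes used above, `net.` for the main network and `net_swa.` for the SWA copy. A minimal, dependency-free sketch of that prefix stripping, assuming only that `state_dict` is a mapping from string keys to values (the helper name `strip_prefix` and the example keys are hypothetical):

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Keep only entries whose key starts with `prefix`, with the prefix removed.

    Same logic as the dict comprehensions above: k[len("net."):] equals k[4:].
    """
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


if __name__ == "__main__":
    # Hypothetical checkpoint keys; real ones come from ckpt["state_dict"].
    fake_state_dict = {
        "net.patch_embed.proj.weight": 1,
        "net.head.1.weight": 2,
        "net_swa.patch_embed.proj.weight": 3,
        "net_swa.head.1.weight": 4,
    }
    print(strip_prefix(fake_state_dict, "net."))      # {'patch_embed.proj.weight': 1, 'head.1.weight': 2}
    print(strip_prefix(fake_state_dict, "net_swa."))  # {'patch_embed.proj.weight': 3, 'head.1.weight': 4}
```

The trailing dot in `"net."` matters: a bare `"net"` prefix would also match every `net_swa.` key and leave stray `.` and `_swa.` fragments at the start of the stripped names.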
myatmyintzuthin commented 2 years ago

Thank you so much for the reply. This is my first time writing an inference script in PyTorch, so it was a great help to me. I'm going to share my inference script here in case someone else wants to use it.

# References 
# 1) https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py
# 2) https://github.com/kkoutini/passt_hear21

import csv
import argparse
import numpy as np
import torch
import torchaudio
from hear21passt.base import load_model

SAMPLE_RATE = 32000  # PaSST expects 32 kHz audio


def load_label(label_csv):
    # The CSV has a header row, then rows of: index, id (e.g. "/m/068hy"), display name
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader)  # skip the header row
        return [row[2] for row in reader]


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Example usage: '
                                                 'python inference.py --audio_path ESC-50-master/audio_32k/1-5996-A-6.wav '
                                                 '--model_path checkpoints/epoch=2-step=4799.ckpt')

    parser.add_argument("--model_path", required=True, type=str,
                        help="the trained model you want to test")
    parser.add_argument("--audio_path", required=True, type=str,
                        help="the audio you want to predict (resampled to 32 kHz if needed)")

    args = parser.parse_args()

    label_csv = './esc50/esc_class_labels_indices.csv'  # label and indices for ESC-50 data
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 1. load audio file, mix down to mono and resample to 32 kHz if needed
    audio_path = args.audio_path
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.mean(dim=0, keepdim=True)  # shape: (1, num_samples)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    # 2. load checkpoint
    checkpoint_path = args.model_path
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    net_statedict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
    net_swa = {k[len("net_swa."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # SWA weights

    # 3. load the fine-tuned weights and switch to evaluation mode
    passt_model = load_model(mode="logits")
    passt_model.net.load_state_dict(net_statedict)
    passt_model = passt_model.to(device).eval()
    print(f'[*INFO] load checkpoint: {checkpoint_path}')

    waveform = waveform.to(device)
    with torch.no_grad():
        logits = passt_model(waveform)
        # ESC-50 is single-label, so softmax turns logits into class probabilities
        probs = torch.softmax(logits, dim=-1)
    result_output = probs.cpu().numpy()[0]

    # 4. map the posterior probabilities to labels
    labels = load_label(label_csv)
    sorted_indexes = np.argsort(result_output)[::-1]

    # Print the top-5 predicted classes
    print('[*INFO] predict results:')
    for k in range(5):
        idx = sorted_indexes[k]
        print('{}: {:.4f}'.format(labels[idx], result_output[idx]))
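Step 4 of the script only turns the model outputs into probabilities, ranks them, and looks up label names. That post-processing can be sketched without NumPy or PyTorch, which makes it easy to check in isolation. This assumes the logits for one clip arrive as a plain list of floats in the same order as the rows of the label CSV; the function names `softmax` and `top_k` here are hypothetical helpers, not part of hear21passt.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: Sequence[float], labels: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k (label, probability) pairs with the highest probability."""
    if len(probs) != len(labels):
        raise ValueError("need exactly one label per class")
    # Stable sort: ties keep the lower class index first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in order[:k]]


if __name__ == "__main__":
    # Illustrative 4-class example; a fine-tuned ESC-50 model produces 50 logits per clip.
    labels = ["dog", "rooster", "rain", "sea_waves"]
    logits = [2.0, 0.5, -1.0, 1.0]
    for name, p in top_k(softmax(logits), labels, k=2):
        print("{}: {:.4f}".format(name, p))
```

Because softmax preserves the order of the logits, the top-5 ranking is the same as with a sigmoid; only the printed scores change. Ties are broken by class index here, which can differ from `np.argsort(...)[::-1]`.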