Closed myatmyintzuthin closed 2 years ago
Hi, thanks! You can use passt_hear21 like this:
# Load a fine-tuned PaSST checkpoint into the hear21passt wrapper and run a
# forward pass on a dummy batch.
import torch
from hear21passt.base import load_model, get_scene_embeddings, get_timestamp_embeddings

# Loading the weights
p = "output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt"
# map_location="cpu" keeps the load independent of the device the checkpoint
# was saved from. NOTE: torch.load unpickles the file, so only load
# checkpoints you trust.
ckpt = torch.load(p, map_location="cpu")
state_dict = ckpt['state_dict']
# The checkpoint prefixes every parameter with the attribute name it was
# stored under; strip that prefix so the keys match the bare network.
net_statedict = {k[len("net."):]: v.cpu() for k, v in state_dict.items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v.cpu() for k, v in state_dict.items() if k.startswith("net_swa.")}  # swa weights (not loaded below)

# getting the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(mode="logits").to(device)
model.net.load_state_dict(net_statedict)  # loading the fine-tuned weights
model.eval()  # inference only: switch off dropout / batch-norm updates

# example
# Create the input on the same device as the model; a CPU tensor fed to a
# CUDA model raises a device-mismatch error.
# presumably 3 clips of 5 s at 32 kHz -- TODO confirm expected sample rate
wave_example = torch.ones((3, 32000 * 5), device=device) * 0.5
with torch.no_grad():
    logits = model(wave_example)
Thank you so much for the reply. This is my first time creating an inference script in PyTorch, and it was a great help to me. I am going to share my inference script here in case someone wants to use it.
# References
# 1) https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py
# 2) https://github.com/kkoutini/passt_hear21
import csv
import argparse
import numpy as np
import torch
import torchaudio
from pytorch_lightning import Trainer as plTrainer
from hear21passt.base import load_model
def load_label(label_csv):
    """Return the class labels listed in a label-index CSV file.

    The first row is treated as a header and skipped. For every remaining
    non-empty row, the value in the third column is taken as the label. The
    second column (a unique id such as "/m/068hy") is not returned and is
    therefore not collected.

    Args:
        label_csv: Path to the comma-separated label file.

    Returns:
        A list of label strings in file order (empty if the file has no
        data rows).

    Raises:
        IndexError: If a non-empty data row has fewer than three columns.
    """
    # newline='' is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(label_csv, 'r', newline='', encoding='utf-8') as f:
        rows = csv.reader(f, delimiter=',')
        next(rows, None)  # drop the header; tolerate a completely empty file
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # skip those instead of failing on row[2].
        return [row[2] for row in rows if row]
if __name__ == '__main__':
    # Command-line inference: load one audio file and one fine-tuned
    # checkpoint, run the model, and print the five highest-scoring labels.
    parser = argparse.ArgumentParser(description='Example of parser:'
                                                 'python inference --audio_path ESC-50-master/audio_32k/1-5996-A-6.wav '
                                                 '--model_path checkpoints/epoch=2-step=4799.ckpt')
    parser.add_argument("--model_path", required=True, type=str,
                        help="the trained model you want to test")
    parser.add_argument("--audio_path", required=True,
                        help='the audio you want to predict, sample rate 32k.',
                        type=str)
    args = parser.parse_args()

    label_csv = './esc50/esc_class_labels_indices.csv'  # label and indices for ESC-50 data
    expected_sample_rate = 32000  # per the --audio_path help text
    top_k = 5

    # Use the GPU when there is one, otherwise fall back to the CPU instead
    # of crashing on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 1. load audio file
    audio_path = args.audio_path
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != expected_sample_rate:
        # The script does not resample; warn rather than silently produce
        # predictions from audio at an unexpected rate.
        print(f'[*WARN] {audio_path} has sample rate {sample_rate}, '
              f'expected {expected_sample_rate}')

    # 2. load checkpoint
    checkpoint_path = args.model_path
    # map_location="cpu" makes the load independent of the device the
    # checkpoint was saved from. NOTE: torch.load unpickles the file, so only
    # load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # Keep only the entries stored under the "net." prefix and strip it so
    # the keys line up with the bare network's parameter names.
    net_statedict = {k[len("net."):]: v.cpu()
                     for k, v in ckpt['state_dict'].items()
                     if k.startswith("net.")}  # main weights

    # 3. loading the fine-tuned weights
    passt_model = load_model(mode="logits")
    passt_model.net.load_state_dict(net_statedict)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    passt_model = passt_model.to(device)
    passt_model.eval()  # inference only: disable dropout / batch-norm updates

    waveform = waveform.to(device)
    with torch.no_grad():
        output = passt_model(waveform)
        output = torch.sigmoid(output)
    # Only the first row of the output is reported.
    # NOTE(review): torchaudio returns one row per channel, so for a
    # multi-channel file this presumably scores the first channel only.
    result_output = output[0].cpu().numpy()

    # 4. map the post-prob to label
    labels = load_label(label_csv)
    sorted_indexes = np.argsort(result_output)[::-1]

    # Print audio tagging top probabilities
    print('[*INFO] predict results:')
    # Never ask for more entries than the model produced.
    for idx in sorted_indexes[:min(top_k, len(sorted_indexes))]:
        print('{}: {:.4f}'.format(labels[idx], result_output[idx]))
Hello, authors. Thank you for sharing this great work.
I tried to fine-tune the AudioSet-pretrained model
passt-s-f128-p16-s10-ap.476-swa.pt
on the ESC-50 dataset using ex_esc50.py.
The checkpoints were saved in output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt.
I want to load the checkpoint and run inference on an audio file. I tried to load the checkpoint and use passt_hear21 for inference, but I lost track of the process. Could you please share how to run inference on an audio file with the saved checkpoints?