**Closed** — opened by 980202006, closed 2 years ago
You can use the model fine-tuned on OpenMIC here: `openmic-passt-s-f128-10sec-p16-s10-ap.85.pt`
```python
!pip install git+https://github.com/kkoutini/passt_hear21.git

import torch
from hear21passt.base import get_basic_model
from hear21passt.models.passt import get_model as get_model_passt

model = get_basic_model(mode="logits").cuda()
# Replace the transformer so the head outputs the 20 OpenMIC classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=20)

# Load the fine-tuned checkpoint
state = torch.load("/path/to/openmic-passt-s-f128-10sec-p16-s10-ap.85.pt")
model.net.load_state_dict(state)

# wave_signal: batch of raw waveforms, shape (batch, samples)
logits = model(wave_signal)
```
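For reference, `wave_signal` above is a placeholder for a batch of raw waveforms. A minimal sketch of preparing one, assuming the PaSST default of 32 kHz audio and a 10-second clip (the sample rate and shape here are assumptions, not taken from this thread):

```python
import torch

# Assumption: hear21passt models take raw waveforms shaped (batch, samples)
# at 32 kHz, so a 10-second clip is 320000 samples per example.
sample_rate = 32000
clip_seconds = 10

# Placeholder audio; in practice load and resample a real clip instead.
wave_signal = torch.randn(1, sample_rate * clip_seconds)

print(tuple(wave_signal.shape))  # (1, 320000)

# With the model above you would then run (on GPU):
#   logits = model(wave_signal.cuda())   # shape (1, 20)
#   probs = logits.sigmoid()             # multi-label probabilities
```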
What is the order of the labels in the logits?
The labels are in the same order as in the `openmic-2018.npz` file in the dataset here; I believe that is the alphabetical order of the class names.
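To turn the 20 logits into named predictions, here is a hedged sketch. The alphabetical class list below is an assumption about the OpenMIC-2018 ordering; verify it against the class map shipped with the dataset before relying on it:

```python
import torch

# Assumed OpenMIC-2018 instrument classes in alphabetical order --
# verify against the dataset's class map before use.
CLASSES = [
    "accordion", "banjo", "bass", "cello", "clarinet", "cymbals", "drums",
    "flute", "guitar", "mallet_percussion", "mandolin", "organ", "piano",
    "saxophone", "synthesizer", "trombone", "trumpet", "ukulele", "violin",
    "voice",
]

def top_instruments(logits, k=3):
    """Map a (20,) logits tensor to the k most likely instrument names."""
    probs = torch.sigmoid(logits)   # multi-label task, so sigmoid, not softmax
    top = torch.topk(probs, k)
    return [(CLASSES[i], float(p)) for p, i in zip(top.values, top.indices)]

# Example with dummy logits strongly favouring index 12 ("piano"):
dummy = torch.zeros(20)
dummy[12] = 5.0
print(top_instruments(dummy, k=1)[0][0])  # piano
```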
Thank you!
Can I output the labels directly with the pretrained model, or do I need to fine-tune it on OpenMIC-2018?