kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

Is it possible to use this project directly for a code example for instrument recognition? #15

980202006 closed this issue 2 years ago

980202006 commented 2 years ago

Can I output the labels directly with the pretrained model, or do I need to fine-tune it on openmic-2018?

kkoutini commented 2 years ago

You can use the model fine-tuned on OpenMIC, available here: openmic-passt-s-f128-10sec-p16-s10-ap.85.pt

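```python
# Install first (in a notebook cell): !pip install git+https://github.com/kkoutini/passt_hear21.git
import torch

from hear21passt.base import get_basic_model
from hear21passt.models.passt import get_model as get_model_passt

model = get_basic_model(mode="logits")
# Replace the transformer with one that has a 20-class output (OpenMIC has 20 instrument classes)
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=20)

# Load the OpenMIC fine-tuned weights into the new transformer
state = torch.load("/path/to/openmic-passt-s-f128-10sec-p16-s10-ap.85.pt", map_location="cpu")
model.net.load_state_dict(state)

# Move to the GPU only after replacing model.net, otherwise the new net stays on the CPU
model = model.cuda().eval()

# wave_signal: float tensor of shape [batch, n_samples], sampled at 32 kHz
with torch.no_grad():
    logits = model(wave_signal.cuda())
```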
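The snippet above leaves `wave_signal` undefined. A minimal sketch of preparing one clip with NumPy only, assuming the model expects mono audio at 32 kHz and 10-second inputs (the checkpoint name contains "10sec"). `prepare_clip`, `TARGET_SR` and `CLIP_SECONDS` are hypothetical names, and the linear-interpolation resampler is a crude stand-in for a proper band-limited one such as those in torchaudio or librosa.

```python
import numpy as np

TARGET_SR = 32000   # assumed: PaSST checkpoints are trained on 32 kHz audio
CLIP_SECONDS = 10   # assumed from the "10sec" in the checkpoint name


def prepare_clip(wave, sr, target_sr=TARGET_SR, seconds=CLIP_SECONDS):
    """Return a float32 array of shape [1, target_sr * seconds].

    wave: 1-D array (samples,) or 2-D array (samples, channels).
    """
    wave = np.asarray(wave, dtype=np.float32)
    if wave.ndim == 2:
        # Down-mix multi-channel audio to mono by averaging the channels
        wave = wave.mean(axis=1)
    if sr != target_sr:
        # Crude linear-interpolation resampling (hypothetical choice for the sketch)
        n_out = int(round(len(wave) / sr * target_sr))
        t_in = np.arange(len(wave)) / sr
        t_out = np.arange(n_out) / target_sr
        wave = np.interp(t_out, t_in, wave).astype(np.float32)
    n_target = target_sr * seconds
    if len(wave) >= n_target:
        wave = wave[:n_target]  # keep only the first `seconds` seconds
    else:
        wave = np.pad(wave, (0, n_target - len(wave)))  # zero-pad the tail
    return wave[np.newaxis, :]  # add a batch dimension
```

With audio loaded as a NumPy array `audio` at sample rate `sr` (for example via soundfile), `wave_signal = torch.from_numpy(prepare_clip(audio, sr))` gives a `[1, 320000]` float tensor for the call above.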
980202006 commented 2 years ago

What is the order of the labels in the logits?

kkoutini commented 2 years ago

The labels are in the same order as in the "openmic-2018.npz" file in the dataset here. I think it's the alphabetical order of the classes.
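Because OpenMIC-2018 is multi-label (a clip can contain several instruments), the logits are usually read with an independent sigmoid per class rather than a softmax. A sketch of turning them into instrument names, assuming the class order is the alphabetical order of the dataset's class map (check it against the class map shipped with your copy of the dataset, e.g. `class-map.json`) and a 0.5 threshold. `OPENMIC_CLASSES` and `decode_instruments` are hypothetical names.

```python
import numpy as np

# Assumed class order: the 20 OpenMIC-2018 instruments, alphabetically.
# Verify this against the class map in your copy of the dataset.
OPENMIC_CLASSES = [
    "accordion", "banjo", "bass", "cello", "clarinet", "cymbals", "drums",
    "flute", "guitar", "mallet_percussion", "mandolin", "organ", "piano",
    "saxophone", "synthesizer", "trombone", "trumpet", "ukulele", "violin",
    "voice",
]


def decode_instruments(logits, threshold=0.5, classes=OPENMIC_CLASSES):
    """Turn a [batch, 20] array of logits into per-clip lists of (name, prob)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Multi-label: independent sigmoid per class, not a softmax across classes
    probs = 1.0 / (1.0 + np.exp(-logits))
    results = []
    for row in probs:
        hits = [(classes[i], float(p)) for i, p in enumerate(row) if p >= threshold]
        # Most confident instruments first
        hits.sort(key=lambda item: item[1], reverse=True)
        results.append(hits)
    return results
```

For the model output above, `decode_instruments(logits.cpu().numpy())` returns one list per clip; lowering `threshold` trades precision for recall.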

980202006 commented 2 years ago

Thank you!