YuanGongND / ssast

Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Using the SSAST pretrained model for inference #36

Open gavinwwf opened 2 weeks ago

gavinwwf commented 2 weeks ago

@YuanGongND I used the SSAST pretrained model for inference, but I get different results every time, and all of the scores in the output are close to each other. What could be the reason? A typical result looks like this:

[{"label": "Electric toothbrush", "score": 0.849498987197876},
 {"label": "Blender", "score": 0.8397527933120728},
 {"label": "Tambourine", "score": 0.8310427665710449},
 {"label": "Race car, auto racing", "score": 0.8218237161636353},
 {"label": "Pink noise", "score": 0.8042027354240417},
 {"label": "Writing", "score": 0.7958802580833435},
 {"label": "Singing", "score": 0.7875975966453552},
 {"label": "Telephone dialing, DTMF", "score": 0.7849113941192627},
 {"label": "Ambulance (siren)", "score": 0.7678646445274353},
 {"label": "Country", "score": 0.7541956901550293}]
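One way to check whether the per-run variation comes from the model weights rather than the audio pipeline is to build the model twice and compare the classification-head parameters: the self-supervised SSAST checkpoints are pretrain-stage only, so a fresh head is initialized randomly on every construction. This is a minimal sketch; the `models` import path and the `mlp_head` attribute name follow ast_models.py in this repo and may need adjusting:

import torch
from models import ASTModel  # adjust to wherever ast_models.py lives in your setup

def build():
    return ASTModel(label_dim=527, fshape=16, tshape=16, fstride=10, tstride=10,
                    input_tdim=200, model_size='tiny', pretrain_stage=False,
                    load_pretrained_mdl_path='SSAST-Tiny-Patch-400.pth')

# If the checkpoint carries no fine-tuned head, the Linear layer in mlp_head is
# re-initialized on every call, so the two models disagree and the sigmoid
# scores cluster near a common value.
m1, m2 = build(), build()
print(torch.allclose(m1.mlp_head[1].weight, m2.mlp_head[1].weight))  # False => random head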

My code is as follows:

import numpy as np
import torch

from models import ASTModel  # from this repo's src/models; adjust the import path to your setup


def model_fn(model_dir):
    """Load the model and set weights."""
    input_tdim = 200
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_path = f'{model_dir}/SSAST-Tiny-Patch-400.pth'

    # fstride, tstride = int(checkpoint_path.split('/')[-1].split('_')[1]), int(
    #     checkpoint_path.split('/')[-1].split('_')[2].split('.')[0])

    ast_mdl = ASTModel(label_dim=527, fshape=16, tshape=16, fstride=10, tstride=10,
                       input_tdim=input_tdim, model_size='tiny', pretrain_stage=False,
                       load_pretrained_mdl_path=checkpoint_path)

    audio_model = torch.nn.DataParallel(ast_mdl)
    checkpoint = load_modified_checkpoint(checkpoint_path, audio_model, device)
    audio_model.load_state_dict(checkpoint)
    audio_model = audio_model.to(device)
    audio_model.eval()

    labels = load_label(f'{model_dir}/class_labels_indices.csv')

    return audio_model, labels
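
`load_modified_checkpoint` is my own helper. `load_label` just reads the AudioSet label map; a minimal sketch, assuming the standard class_labels_indices.csv layout (index, mid, display_name):

import csv

def load_label(label_csv):
    """Return display names from class_labels_indices.csv (index, mid, display_name)."""
    with open(label_csv, 'r') as f:
        rows = list(csv.reader(f, delimiter=','))
    # Skip the header row; the third column is the human-readable label.
    return [row[2] for row in rows[1:]]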

def predict_fn(input_data, model):
    """The predict_fn is invoked with the return value of input_fn."""
    audio_model, labels = model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_tdim = 200
    feats = make_features(input_data, mel_bins=128, target_length=input_tdim)
    feats_data = feats.expand(1, input_tdim, 128).to(device)  # add a batch dimension: [1, input_tdim, 128]

    with torch.no_grad():
        output = audio_model(feats_data, task='ft_cls')
        output = torch.sigmoid(output)

    result_output = output.data.cpu().numpy()[0]
    sorted_indexes = np.argsort(result_output)[::-1]

    top_k = 10
    top_k_labels = [(labels[idx], result_output[idx]) for idx in sorted_indexes[:top_k]]

    return top_k_labels
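
`make_features` follows the log-mel filterbank pipeline from the AST/SSAST inference examples. A minimal sketch, assuming 16 kHz audio and the AudioSet normalization statistics used in those repos:

import torch
import torchaudio

def make_features(input_data, mel_bins, target_length=1024):
    """Log-mel filterbank features of shape [target_length, mel_bins]."""
    # Assumes input_data is a path to a wav file; adapt torchaudio.load if
    # your input_fn already returns a waveform tensor.
    waveform, sr = torchaudio.load(input_data)
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
    # Zero-pad or trim along time so every clip has exactly target_length frames.
    p = target_length - fbank.shape[0]
    if p > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, p))(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # Normalize with the AudioSet mean/std from the AST repo.
    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank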