YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

load a trained model only for evaluation #28

Closed hbellafkir closed 3 years ago

hbellafkir commented 3 years ago

I want to load a model that I recently trained with target_length=512 and a 48 kHz sample rate, using the following code:

  sd = torch.load(model_path, map_location="cuda")
  audio_model = ASTModel(label_dim=84, fstride=10, tstride=10, input_fdim=128, input_tdim=512, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
  audio_model = torch.nn.DataParallel(audio_model)
  audio_model.load_state_dict(sd, strict=False)

load_state_dict reports that all keys matched successfully, but evaluation yields a seemingly random mean average precision (mAP) that does not match the value seen during training.

YuanGongND commented 3 years ago

Hi there,

I am not sure whether this is a training/test mismatch or a model loading issue. Are you using our code and getting test mAPs every epoch during training, and do those not match the test mAP you get when you load the model and run inference separately? If that is the case, is there a difference in data loading (especially the norm stats) between these two processes?

Or do you only get mAP on the training set during training, and that doesn't match the test mAP? I think it is normal for the training and test mAPs to differ.

I don't see an issue with your model loading, provided you also saved your model as a DataParallel object.

-Yuan

hbellafkir commented 3 years ago

This issue was solved by initializing the MLP head correctly using:

  with torch.no_grad():
      self.mlp_head[0].weight = nn.Parameter(sd["module.mlp_head.0.weight"])
      self.mlp_head[0].bias = nn.Parameter(sd["module.mlp_head.0.bias"])
      self.mlp_head[1].weight = nn.Parameter(sd["module.mlp_head.1.weight"])
      self.mlp_head[1].bias = nn.Parameter(sd["module.mlp_head.1.bias"])

YuanGongND commented 3 years ago

It seems that you are using audio_model.mlp_head rather than audio_model.module.mlp_head for classification, which indicates that your audio_model is not a torch.nn.DataParallel object. If so, not only the mlp_head but all other parts of the AST model should be in the audio_model dict rather than the audio_model.module dict.
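To make the key naming concrete, here is a minimal sketch of how a checkpoint saved from a torch.nn.DataParallel wrapper can be loaded into an unwrapped model. TinyModel is a hypothetical stand-in for ASTModel, and strip_module_prefix is a helper invented for this illustration, not part of the AST code.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for ASTModel; only the key naming matters here."""

    def __init__(self):
        super().__init__()
        self.mlp_head = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 3))

    def forward(self, x):
        return self.mlp_head(x)


def strip_module_prefix(state_dict):
    """Illustrative helper: drop the leading "module." that DataParallel adds."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# A state dict taken from a DataParallel wrapper has "module."-prefixed keys,
# e.g. "module.mlp_head.0.weight".
wrapped = nn.DataParallel(TinyModel())
sd = wrapped.state_dict()
wrapped_keys = sorted(sd.keys())
print(wrapped_keys)

# An unwrapped model expects the same keys without the prefix.
plain = TinyModel()
result = plain.load_state_dict(strip_module_prefix(sd), strict=True)
print(result)
```

The alternative is to wrap the evaluation model in torch.nn.DataParallel before loading, as in the original post, so that the prefixed keys match directly.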

One thing you could try is setting strict=True when you load the model and seeing what it reports.
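As a rough illustration of what that check can reveal, the sketch below loads a "module."-prefixed state dict into an unwrapped model under both settings. TinyModel is a hypothetical stand-in for ASTModel, and the mismatch is constructed on purpose; it is not necessarily what happened in this issue.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for ASTModel."""

    def __init__(self):
        super().__init__()
        self.mlp_head = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 3))

    def forward(self, x):
        return self.mlp_head(x)


# State dict from a DataParallel wrapper: every key starts with "module.".
sd = nn.DataParallel(TinyModel()).state_dict()

# Deliberately NOT wrapped, so none of the keys line up.
plain = TinyModel()

# strict=False raises nothing; mismatches only show up in the return value,
# and the model keeps its randomly initialized weights.
report = plain.load_state_dict(sd, strict=False)
print("missing:", report.missing_keys)
print("unexpected:", report.unexpected_keys)

# strict=True turns the same mismatch into a RuntimeError.
try:
    plain.load_state_dict(sd, strict=True)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)
    print(strict_error)
```

With strict=False it is worth inspecting the returned missing_keys and unexpected_keys explicitly, since a silent mismatch leaves the affected layers at their initial values.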