YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Wrong .pth name? #14

Closed: jvel07 closed this issue 3 years ago

jvel07 commented 3 years ago

Hi, thanks for the awesome contribution! I have prepared my data using your pipeline. When running the experiments, I get:

ImageNet pretraining: True, AudioSet pretraining: True
Traceback (most recent call last):
  File "../../src/run.py", line 99, in <module>
    audioset_pretrain=args.audioset_pretrain, model_size='base384')
  File "/home/user/PycharmProjects/ast/src/models/ast_models.py", line 143, in __init__
    sd = torch.load('../../pretrained_models/ast_audioset.pth', map_location=device)
  File "/home/user/anaconda3/envs/ast/lib/python3.7/site-packages/torch/serialization.py", line 579, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/home/user/anaconda3/envs/ast/lib/python3.7/site-packages/torch/serialization.py", line 230, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/home/user/anaconda3/envs/ast/lib/python3.7/site-packages/torch/serialization.py", line 211, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '../../pretrained_models/ast_audioset.pth'

Looking at line 143 of ast_models.py, the file name passed to torch.load differs from the name the model is downloaded under:

wget.download(audioset_mdl_url, out='../../pretrained_models/audioset_10_10_0.4593.pth')
sd = torch.load('../../pretrained_models/ast_audioset.pth', map_location=device)

Changing the file name from "ast_audioset.pth" to "audioset_10_10_0.4593.pth" fixed the missing-file error. I'm posting this in case anyone else runs into it.
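One way to keep the two names from drifting apart again is to compute the checkpoint path once and use it for both the download and the load. Below is a minimal sketch, assuming the `wget` package that ast_models.py already calls. `ensure_checkpoint` is a hypothetical helper and is not part of the repo.

```python
import os


def ensure_checkpoint(url, path, download=None):
    """Hypothetical helper: download `url` to `path` only if it is missing, then return `path`.

    Returning the same path that was written ensures the caller loads
    exactly the file that was downloaded.
    """
    if not os.path.exists(path):
        parent = os.path.dirname(path)
        if parent:
            # Create ../../pretrained_models (or similar) if it does not exist yet.
            os.makedirs(parent, exist_ok=True)
        if download is None:
            import wget  # same package ast_models.py uses for the download
            download = wget.download
        download(url, out=path)
    return path
```

Usage, assuming the variables from the snippet above: `sd = torch.load(ensure_checkpoint(audioset_mdl_url, '../../pretrained_models/audioset_10_10_0.4593.pth'), map_location=device)`. Because the load reads the path the helper returns, a rename only has to happen in one place.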

YuanGongND commented 3 years ago

Hi there,

Thanks so much for pointing this out! You are correct that the saving path was wrong; I have just fixed it in ast_models.py.

-Yuan