Audio-WestlakeU / audiossl

A library built for easier audio self-supervised training, downstream tasks evaluation

issue on ckpt loading in train_distill.py #14

Open folkartist opened 3 weeks ago

folkartist commented 3 weeks ago

I am trying to run train_distill.py to replicate the C2F method, and then modify the code to fine-tune the model on my own datasets. I downloaded the checkpoints for both the clip and frame versions and set the corresponding path arguments. But I get an error on the line state_dict_cls = cls_s["state_dict"]. Error message:

 KeyError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'state_dict'
  File "/work/wpd/audiossl/train_distill.py", line 50, in <module>
    state_dict_cls = cls_s["state_dict"]

I tried changing "state_dict" to "student", but that causes a new error... How do I fix it? Thank you in advance for any help, and really good work!

lmaxwell commented 2 weeks ago

Hi, very sorry for the late response!

Which checkpoint did you use for the argument --pretrain_ckpt_path_clip? It should be the ATST-Clip-Audioset checkpoint, that is, the clip model finetuned on AudioSet. I guess you used the ATST-Clip checkpoint, which is the wrong one. I hope this solves your problem.

I also checked train_distill.py: the code had not been adapted to pytorch 2.1.1 and lightning 2.2.1. Please use the newest commit.
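
A quick way to tell which kind of checkpoint a file is, before passing it to --pretrain_ckpt_path_clip, is to look at its top-level keys. The sketch below is only an illustration: summarize_checkpoint is a hypothetical helper, not part of audiossl, and the example dictionaries are stand-ins whose key layouts ("student" for a pretraining checkpoint, "state_dict" for a finetuned one) are assumptions drawn from this thread rather than verified against the released files.

```python
def summarize_checkpoint(ckpt):
    """List the top-level keys of an already loaded checkpoint dict and
    report whether it has the "state_dict" entry that train_distill.py
    indexes on the failing line."""
    return {
        "keys": sorted(ckpt.keys()),
        "has_state_dict": "state_dict" in ckpt,
    }


# Stand-in dicts for illustration only. A real checkpoint would typically be
# loaded first, for example:
#   ckpt = torch.load(path, map_location="cpu")
# The layouts below are assumptions, not the actual contents of the files.
pretrain_like = {"student": {}}
finetuned_like = {"state_dict": {}, "epoch": 0}

print(summarize_checkpoint(pretrain_like))
print(summarize_checkpoint(finetuned_like))
```

If "has_state_dict" comes back False for the file given to --pretrain_ckpt_path_clip, that would be consistent with having passed the ATST-Clip checkpoint instead of ATST-Clip-Audioset.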