RetroCirce / HTS-Audio-Transformer

The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"
https://arxiv.org/abs/2202.00874
MIT License

Training and inferring with a dataset containing 4 classes #31

Closed: JonathanFL closed this issue 1 year ago

JonathanFL commented 1 year ago

Hi Ke,

I hope you are having a happy holiday :)

I used htsat_esc_training.ipynb to retrain a model on my own data, which contains only 4 classes (I removed all the ESC-50 data and replaced it with my data, split into 5 folds). When I load the state_dict for inference, the following error occurs:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [2], line 2
      1 # Inference
----> 2 Audiocls = Audio_Classification(model_path, config)
      4 pred_prob = Audiocls.predict("scptj-6t8dl.wav")
      6 print('Audiocls predict output: ', pred_prob)

Cell In [1], line 44, in Audio_Classification.__init__(self, model_path, config)
     42 for key in ckpt["state_dict"]:
     43     temp_ckpt[key[10:]] = ckpt['state_dict'][key]
---> 44 self.sed_model.load_state_dict(temp_ckpt)
     45 self.sed_model.to(self.device)
     46 self.sed_model.eval()

File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\nn\modules\module.py:1604, in Module.load_state_dict(self, state_dict, strict)
   1599         error_msgs.insert(
   1600             0, 'Missing key(s) in state_dict: {}. '.format(
   1601                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1603 if len(error_msgs) > 0:
-> 1604     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1605                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1606 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for HTSAT_Swin_Transformer:
    size mismatch for tscam_conv.weight: copying a param with shape torch.Size([50, 768, 2, 3]) from checkpoint, the shape in current model is torch.Size([4, 768, 2, 3]).
    size mismatch for tscam_conv.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([4]).
    size mismatch for head.weight: copying a param with shape torch.Size([50, 50]) from checkpoint, the shape in current model is torch.Size([4, 4]).
    size mismatch for head.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([4]).
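Every parameter in the error has a first dimension of 50 in the checkpoint but 4 in the freshly built model, so the classifier layers (`tscam_conv` and `head`) in the loaded file were trained for 50 classes. Below is a minimal sketch of a pre-load check that reports such mismatches by name. It assumes the Lightning wrapper stores the network under a `sed_model` attribute, which is presumably what the notebook's `key[10:]` slice strips off (`"sed_model."` is 10 characters long). The helper names are hypothetical. The functions take plain `{name: shape}` dicts so they run without PyTorch; with real tensors you would build them with `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def strip_prefix(shapes: Dict[str, Shape], prefix: str = "sed_model.") -> Dict[str, Shape]:
    """Remove an assumed wrapper prefix from parameter names.

    Unlike a fixed key[10:] slice, keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): shape
        for key, shape in shapes.items()
    }


def shape_mismatches(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """Describe every parameter present in both dicts whose shapes differ."""
    return [
        f"{key}: checkpoint {ckpt[key]} vs model {model[key]}"
        for key in sorted(ckpt.keys() & model.keys())
        if ckpt[key] != model[key]
    ]


# Shapes copied from the traceback; the "sed_model." prefix is an assumption.
CKPT_SHAPES = {
    "sed_model.tscam_conv.weight": (50, 768, 2, 3),
    "sed_model.tscam_conv.bias": (50,),
    "sed_model.head.weight": (50, 50),
    "sed_model.head.bias": (50,),
}
MODEL_SHAPES = {
    "tscam_conv.weight": (4, 768, 2, 3),
    "tscam_conv.bias": (4,),
    "head.weight": (4, 4),
    "head.bias": (4,),
}

if __name__ == "__main__":
    for line in shape_mismatches(strip_prefix(CKPT_SHAPES), MODEL_SHAPES):
        print(line)
```

Run on the traceback's shapes, this prints one line per mismatched parameter, starting with `head.bias: checkpoint (50,) vs model (4,)`. It does not replace `load_state_dict`; it only makes the mismatch easier to read before loading.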

I have changed model_path to one of the checkpoints saved while training the model:

model_path = './workspace/results/exp_htsat_esc_50/checkpoint/lightning_logs/version_34/checkpoints/l-epoch=38-acc=1.000.ckpt'

This confuses me a lot. Do you have any idea why this is happening?
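The shapes in the error also identify the checkpoint that was actually loaded: its `head.bias` has 50 entries, so it was not a 4-class checkpoint, even though `model_path` had been changed. That is consistent with the notebook still holding older state, which the restart in the follow-up comment cleared. Below is a hedged sketch of a sanity check that reads the class count from the checkpoint before calling `load_state_dict`. The function names are hypothetical, and `head.bias` is taken from the parameter names in the traceback.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def classes_in_checkpoint(shapes: Dict[str, Shape], bias_key: str = "head.bias") -> int:
    """Return the class count stored in the classifier bias (its first dimension).

    Accepts the bare name or any prefixed form such as "sed_model.head.bias".
    """
    for key, shape in shapes.items():
        if key == bias_key or key.endswith("." + bias_key):
            return shape[0]
    raise KeyError(f"no {bias_key!r} parameter found in checkpoint")


def check_class_count(shapes: Dict[str, Shape], expected: int, path: str) -> None:
    """Raise a readable error before load_state_dict if the head size is wrong."""
    found = classes_in_checkpoint(shapes)
    if found != expected:
        raise ValueError(
            f"{path} contains a {found}-class head, "
            f"but the model is configured for {expected} classes"
        )


if __name__ == "__main__":
    # Hypothetical shapes mirroring the traceback: a 50-class checkpoint, a 4-class model.
    ckpt_shapes = {"sed_model.head.weight": (50, 50), "sed_model.head.bias": (50,)}
    try:
        check_class_count(ckpt_shapes, expected=4, path="example.ckpt")
    except ValueError as err:
        print(err)
```

In the notebook, the shapes could come from `torch.load(model_path, map_location="cpu")["state_dict"]`, converted with `tuple(v.shape)`. The `expected` value would come from whatever config setting controls the number of classes; the exact name in the repo's config file is not shown in this thread, so check it there.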

JonathanFL commented 1 year ago

It worked after the notebook was restarted. Could you delete this issue? Thanks.