tarotez / sleepstages


RuntimeError: Error(s) in loading state_dict for cnn_lstm #6

Status: Closed (impPie closed this issue 1 year ago)

impPie commented 1 year ago

When running `python offline.py`, this error is raised:

```
  File "c:\Users\ImpWa\Desktop\sleepstages-main\code\classifierClient.py", line 164, in setStagePredictor
    classifier.load_weights(model_path)
  File "c:\Users\ImpWa\Desktop\sleepstages-main\code\deepClassifier.py", line 615, in load_weights
    self.model.load_state_dict(torch.load(weight_path, map_location='cpu'), False)
  File "D:\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for cnn_lstm:
        size mismatch for batns_for_stft.1.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.1.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.1.running_mean: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.1.running_var: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.2.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.2.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.2.running_mean: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.2.running_var: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.3.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.3.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.3.running_mean: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for batns_for_stft.3.running_var: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for convs_for_stft.0.weight: copying a param with shape torch.Size([8, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 1, 3, 3]).
        size mismatch for convs_for_stft.1.weight: copying a param with shape torch.Size([8, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 16, 3, 3]).
        size mismatch for convs_for_stft.2.weight: copying a param with shape torch.Size([8, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 16, 3, 3]).
        size mismatch for convs_for_stft.3.weight: copying a param with shape torch.Size([8, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 16, 3, 3]).
        size mismatch for batn_combined.weight: copying a param with shape torch.Size([688]) from checkpoint, the shape in current model is torch.Size([736]).
        size mismatch for batn_combined.bias: copying a param with shape torch.Size([688]) from checkpoint, the shape in current model is torch.Size([736]).
        size mismatch for batn_combined.running_mean: copying a param with shape torch.Size([688]) from checkpoint, the shape in current model is torch.Size([736]).
        size mismatch for batn_combined.running_var: copying a param with shape torch.Size([688]) from checkpoint, the shape in current model is torch.Size([736]).
        size mismatch for final_fc_no_lstm.weight: copying a param with shape torch.Size([3, 688]) from checkpoint, the shape in current model is torch.Size([3, 736]).
        size mismatch for fulc_combined_lstm.weight: copying a param with shape torch.Size([32, 688]) from checkpoint, the shape in current model is torch.Size([32, 736]).
```
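Note that the `False` argument in `deepClassifier.py` sets `strict=False`, which only suppresses missing/unexpected-key errors; PyTorch still raises on shape mismatches. A quick way to locate disagreements before calling `load_state_dict` is to compare shapes directly. The helper below is a hypothetical sketch (not part of sleepstages) that works on plain name-to-shape dictionaries, such as `{k: tuple(v.shape) for k, v in state_dict.items()}`:

```python
# Hypothetical helper: report parameters whose shapes differ between
# a checkpoint and the current model definition.
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Both arguments map parameter names to shape tuples."""
    return [
        (name, ckpt, model_shapes[name])
        for name, ckpt in checkpoint_shapes.items()
        if name in model_shapes and model_shapes[name] != ckpt
    ]

# Shapes taken from the error message above.
ckpt = {"batns_for_stft.1.weight": (8,), "convs_for_stft.0.weight": (8, 1, 3, 3)}
model = {"batns_for_stft.1.weight": (16,), "convs_for_stft.0.weight": (16, 1, 3, 3)}
print(find_shape_mismatches(ckpt, model))
```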
tarotez commented 1 year ago

After the hyperparameters of the CNN that processes the STFT of the EEG were made configurable, the existing parameter files needed to be updated to match. In the latest commit, the following lines were added to specify the hyperparameters for a demo trained model:

```
"torch_filter_nums_for_stft" : [8,8,8,8],
"torch_kernel_sizes_for_stft" : [3,3,3,3],
"torch_strides_for_stft" : [1,2,2,2],
```
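The filter counts determine the conv layers' channel sizes, which is why the checkpoint weights are shaped `[8, 1, 3, 3]` for the first layer and `[8, 8, 3, 3]` for the rest, while a model built with 16 filters expects `[16, 1, 3, 3]` and `[16, 16, 3, 3]`. A minimal sketch of that relationship, assuming the first layer takes a single input channel:

```python
# Sketch (assumption: one input channel to the first conv layer):
# derive (in_channels, out_channels) per layer from the filter list.
filter_nums = [8, 8, 8, 8]  # "torch_filter_nums_for_stft"
channels = list(zip([1] + filter_nums[:-1], filter_nums))
print(channels)  # → [(1, 8), (8, 8), (8, 8), (8, 8)]
```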