JuneRen closed 1 year ago
The feature dimension of the input to SpeechFormer should be divisible by num_heads.
If it is not divisible, the input features will be cropped:
self.input_dim = input_dim//num_heads * num_heads
if self.input_dim != x.shape[-1]: x = x[:, :, :self.input_dim]
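To make the cropping rule concrete, here is a minimal pure-Python sketch of the same arithmetic. The helper names `cropped_dim` and `crop_features` are hypothetical, plain lists stand in for tensors, and `num_heads = 8` is an assumed value that the thread does not state. The sketch also shows that slicing can only shorten a feature vector: a 76-dimensional input sliced to `:96` stays at 76, so the crop cannot repair an input that is narrower than what the model was built for.

```python
def cropped_dim(input_dim: int, num_heads: int) -> int:
    """Largest multiple of num_heads that does not exceed input_dim."""
    return input_dim // num_heads * num_heads


def crop_features(frame, target_dim):
    """Keep the first target_dim values of one feature frame (a plain list here)."""
    return frame[:target_dim]


num_heads = 8  # assumption for illustration; not stated in the thread

# A 100-dim feature is cropped down to a multiple of num_heads.
print(cropped_dim(100, num_heads))                       # 96
print(len(crop_features(list(range(100)), 96)))          # 96

# A 76-dim feature is narrower than 96, so slicing to :96 changes nothing.
print(len(crop_features(list(range(76)), 96)))           # 76
```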
The feature dimension at inference should be consistent with that at training. In your case, it seems that the feature dimension you used when training the model is 96, while the dimension you used when testing is 76. Please check this.
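One way to catch this kind of mismatch early is to compare the feature dimension against the value used at training before calling the model. The sketch below is only an illustration: `check_feature_dim` is a hypothetical helper, it works on a plain shape tuple, and with PyTorch one would presumably pass `tuple(x.shape)`.

```python
def check_feature_dim(shape, expected_dim):
    """Raise a readable error if the last axis differs from the training-time dimension.

    shape: tuple such as (batch, time, feature_dim)
    expected_dim: feature dimension the model was trained with
    """
    actual_dim = shape[-1]
    if actual_dim != expected_dim:
        raise ValueError(
            f"Feature dimension mismatch: model expects {expected_dim}, "
            f"but the input has {actual_dim} (input shape {shape})."
        )


# Shapes taken from the error message in this thread.
check_feature_dim((1, 101, 96), expected_dim=96)  # passes silently

try:
    check_feature_dim((1, 101, 76), expected_dim=96)
except ValueError as err:
    print(err)
```

Failing here gives a message that names both dimensions, instead of an error raised deep inside the layer normalization call.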
I have taken care of it. Thank you very much.
  File "train_model_new.py", line 173, in test
    pred_logits = initail_model(features)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data4/rj/SpeechFormer/model/speechformer.py", line 122, in forward
    x = self.layers(x).squeeze(dim=1)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data4/rj/SpeechFormer/model/speechformer.py", line 83, in forward
    output = self.input_norm(x)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 170, in forward
    input, self.normalized_shape, self.weight, self.bias, self.eps)
  File "/home/kaldi/.local/lib/python3.7/site-packages/torch/nn/functional.py", line 2095, in layer_norm
    torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[96], expected input with shape [*, 96], but got input of size[1, 101, 76]
Because the code uses model_json['input_dim'] = (self.feat_dim // model_json['num_heads']) * model_json['num_heads'], the error above is raised at inference time. Have you run into this? Thanks.