kuratahiroyuki / BERT6mA

This package provides the CLI for BERT6mA, which predicts 6mA sites for 11 species.
MIT License

RuntimeError: mat1 and mat2 shapes cannot be multiplied (76000x1 and 100x400) #3

Open · xzx0554 opened this issue 4 months ago

xzx0554 commented 4 months ago

The error message I receive when using it is as follows. How can I solve this issue?

```
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 113
    110 dataset = import_fasta(test_path)
    112 net = burt_process(out_path, deep_model_path = common_path + "/deepmodel/6mA" + 'A.thaliana' + "/deep_model", batch_size = 20, thresh = float(0.5))
--> 113 net.pre_training(dataset, w2v_model)

Cell In[1], line 86
     84 for i, (emb_mat, seq_id) in enumerate(loader):
     85     with torch.no_grad():
---> 86         outputs = net(emb_mat)
     88         probs.extend(outputs.cpu().detach().squeeze(1).numpy().flatten().tolist())
     89         pred_labels.extend((np.array(outputs.cpu().detach().squeeze(1).numpy()) + 1 - self.thresh).astype(np.int16))

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\20132\Desktop\项目\remake_6ma\BERT6mA-main\BERT6mA\predict\Bert_network.py:84, in BERT.forward(self, output)
     82 self.attn_list = []
     83 for layer in self.layers:
---> 84     output, enc_self_attn = layer(output)
     85     self.attn_list.append(enc_self_attn)
     87 output = output.view(output.size(0), -1)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\20132\Desktop\项目\remake_6ma\BERT6mA-main\BERT6mA\predict\Bert_network.py:69, in EncoderLayer.forward(self, enc_inputs)
     67 def forward(self, enc_inputs):
---> 69     enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
     70     enc_outputs = self.pos_ffn(enc_outputs)
     72     return enc_outputs, attn

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\20132\Desktop\项目\remake_6ma\BERT6mA-main\BERT6mA\predict\Bert_network.py:39, in MultiHeadAttention.forward(self, Q, K, V)
     36 def forward(self, Q, K, V):
     37     residual, batch_size = Q, Q.size(0)
---> 39     q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_q).transpose(1,2)
     41     k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
     42     v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (76000x1 and 100x400)
```
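For reference, the two operands in the error decode as follows: mat1 is the attention input flattened to 2-D (76000 × 1, i.e. 76,000 token vectors with only one feature each; with batch_size = 20 that works out to 20 × 3800), and mat2 is the transposed weight of a nn.Linear(100, 400), which expects 100 features per token. A minimal sketch that reproduces the same error under those assumed shapes:

```python
import torch
import torch.nn as nn

# Stand-in for W_Q, inferred from the "100x400" operand in the error:
# a Linear projecting 100 input features to 400 output features.
w_q = nn.Linear(100, 400)

# Input whose trailing feature dimension is 1 instead of 100.
# 20 (the batch_size passed to burt_process) * 3800 = 76000, matching
# the "76000x1" operand; the exact 20 x 3800 split is an assumption.
bad_input = torch.randn(20, 3800, 1)

w_q(bad_input)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (76000x1 and 100x400)
```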

xzx0554 commented 4 months ago

The offending code appears to be:

```python
q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_q).transpose(1, 2)
k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
```
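For context, a Linear layer only constrains the last dimension of its input, so Q, K, and V must arrive with W_Q's in_features (100, per the error message) on their last axis. A sketch of the quoted projection with a correctly shaped input (d_model = 100 and n_heads * d_q = 400 follow from the 100x400 weight; n_heads = 4 and d_q = 100 are illustrative guesses, not necessarily the repo's settings):

```python
import torch
import torch.nn as nn

batch_size, seq_len = 20, 41      # seq_len is arbitrary for this sketch
d_model, n_heads, d_q = 100, 4, 100

W_Q = nn.Linear(d_model, n_heads * d_q)        # 100 -> 400, as in the error

Q = torch.randn(batch_size, seq_len, d_model)  # (batch, seq, 100): what W_Q expects
q_s = W_Q(Q).view(batch_size, -1, n_heads, d_q).transpose(1, 2)
print(q_s.shape)  # torch.Size([20, 4, 41, 100]) = (batch, heads, seq, d_q)
```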

The issue seems to be caused by a mismatch in tensor shapes. When I try to make the shapes match by adding:

```python
Q = Q.squeeze(-1)
K = K.squeeze(-1)
V = V.squeeze(-1)
```

the model fails to initialize properly.
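As a side note, squeeze(-1) only drops a size-1 axis; it cannot manufacture the 100 word2vec features the first Linear layer expects, so a shape mismatch persists either way. A hypothetical first check (reusing the loader variable from the traceback above) is to inspect what the DataLoader actually yields before the forward pass:

```python
# Grab one batch and check its shape before it reaches the network.
emb_mat, seq_id = next(iter(loader))
print(emb_mat.shape)
# Expected: (batch_size, seq_len, 100) word2vec vectors per token.
# The traceback implies the last dimension is 1, which suggests the input
# sequences may not have been embedded with w2v_model before batching.
```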