microsoft / DeBERTa

The implementation of DeBERTa

RuntimeError: Index tensor must have the same number of dimensions as input tensor #33

Open lgstd opened 3 years ago

lgstd commented 3 years ago

An error occurs in DisentangledSelfAttention.forward() when query_states.size(1) > hidden_states.size(1):

https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py, line 165:

```python
p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
```
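The message itself comes from torch.gather, which requires the index tensor to have the same rank as the input tensor. A minimal standalone illustration of that failure mode (dummy tensors and shapes, not the repo's actual variables):

```python
import torch

# Dummy shapes only: a 3-D input gathered with a 4-D index reproduces the same RuntimeError.
p2c_att = torch.randn(20, 77, 188)                        # 3-D input
pos_index = torch.zeros(20, 77, 1, 1, dtype=torch.long)   # becomes 4-D after expand

try:
    torch.gather(p2c_att, dim=-2, index=pos_index.expand(20, 77, 188, 77))
except RuntimeError as e:
    print(e)  # Index tensor must have the same number of dimensions as input tensor
```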

BigBird01 commented 3 years ago

I've never hit this kind of error before. Can you share how you got it? Did you try our example scripts?

lgstd commented 3 years ago

```python
import torch
from DeBERTa.deberta.disentangled_attention import DisentangledSelfAttention
from DeBERTa.deberta.config import ModelConfig

config = ModelConfig()
config.hidden_size = 128
config.num_attention_heads = 4
config.share_att_key = False
config.pos_att_type = 'c2p|p2c|p2p'
config.relative_attention = True
config.position_buckets = -1
config.max_relative_positions = 512
config.max_position_embeddings = 768
config.hidden_dropout_prob = 0.1
config.attention_probs_dropout_prob = 0.1

attn = DisentangledSelfAttention(config)
rel_embeddings = torch.nn.Embedding(getattr(config, 'max_relative_positions', 1024) * 2, config.hidden_size)

# query sequence (188) is longer than the key/value sequence (77)
q = torch.normal(0, 1, [5, 188, 128])
kv = torch.normal(0, 1, [5, 77, 128])
out = attn(hidden_states=kv, attention_mask=None, return_att=False,
           query_states=q, relative_pos=None, rel_embeddings=rel_embeddings.weight)
```

Output:

```
Traceback (most recent call last):
  File "/home/deeplearn/pycharm/plugins/python-ce/helpers/pydev/pydevd.py", line 1477, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/run/media/deeplearn/sata_04/manjaro_ai_apps/pycharm-community-2020.3.3/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/run/media/deeplearn/sata_04/LinuxProjects/github_practice/microsoft_deberta/tst/tst_001.py", line 25, in <module>
    out = attn(hidden_states=kv, attention_mask=None, return_att=False, query_states=q, relative_pos=None, rel_embeddings=rel_embeddings.weight)
  File "/home/deeplearn/miniconda/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/run/media/deeplearn/sata_04/LinuxProjects/github_practice/microsoft_deberta/DeBERTa-master-20210205/DeBERTa/deberta/disentangled_attention.py", line 89, in forward
    rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
  File "/run/media/deeplearn/sata_04/LinuxProjects/github_practice/microsoft_deberta/DeBERTa-master-20210205/DeBERTa/deberta/disentangled_attention.py", line 165, in disentangled_attention_bias
    p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
RuntimeError: Index tensor must have the same number of dimensions as input tensor
python-BaseException
```

BigBird01 commented 3 years ago

In our current version, we only support key and query having the same length.

SparkJiao commented 3 years ago

@BigBird01 Hi, thanks for your great contribution. I'd like to know: if lines 158-159 and lines 164-165 are removed, can position-to-content attention then support a key and query of different lengths?

FreshAirTonight commented 3 years ago

In addition to removing the above lines, you need to modify line 152 to:

```python
r_pos = build_relative_position(key_layer.size(-2), query_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions).to(query_layer.device)
```
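For intuition, swapping the arguments this way should give a relative-position map whose last two dimensions are key-length x query-length, which is what the modified p2c gather below indexes into. A quick standalone check with dummy sizes (it assumes build_relative_position is importable from the same module and returns a map shaped [1, first_arg, second_arg]):

```python
from DeBERTa.deberta.disentangled_attention import build_relative_position

# Dummy sizes: first argument is expected to become dim -2, second argument dim -1.
r = build_relative_position(3, 5, bucket_size=-1, max_position=512)
print(r.shape)  # expected: torch.Size([1, 3, 5])
print(r[0])     # roughly the signed relative distances between the two position axes
```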

and modify line 163 to:

```python
p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), query_layer.size(-2)])).transpose(-1,-2)
```
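The shape logic of this gather can be checked in isolation with dummy tensors (illustrative sizes only, not the repo's actual variables; `span` stands in for the relative-position span):

```python
import torch

bh, klen, qlen, span = 20, 77, 188, 512                 # dummy batch*heads, key len, query len, span
p2c_att = torch.randn(bh, klen, 2 * span)               # key vs. relative-position scores
p2c_pos = torch.randint(0, 2 * span, (1, klen, qlen))   # clamped relative positions, [1, klen, qlen]

out = torch.gather(p2c_att, dim=-1,
                   index=p2c_pos.squeeze(0).expand([bh, klen, qlen])).transpose(-1, -2)
print(out.shape)  # torch.Size([20, 188, 77]) -- [batch*heads, query_len, key_len]
```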

Does this fix the issue?