lgstd opened this issue 3 years ago
I've never hit this kind of error before. Can you share how you got it? Did you try our example scripts?
import torch
from DeBERTa.deberta.disentangled_attention import DisentangledSelfAttention
from DeBERTa.deberta.config import ModelConfig

config = ModelConfig()
config.hidden_size = 128
config.num_attention_heads = 4
config.share_att_key = False
config.pos_att_type = 'c2p|p2c|p2p'
config.relative_attention = True
config.position_buckets = -1
config.max_relative_positions = 512
config.max_position_embeddings = 768
config.hidden_dropout_prob = 0.1
config.attention_probs_dropout_prob = 0.1

attn = DisentangledSelfAttention(config)
rel_embeddings = torch.nn.Embedding(getattr(config, 'max_relative_positions', 1024) * 2, config.hidden_size)

# query length (188) differs from the key/value length (77)
q = torch.normal(0, 1, [5, 188, 128])
kv = torch.normal(0, 1, [5, 77, 128])
out = attn(hidden_states=kv, attention_mask=None, return_att=False, query_states=q, relative_pos=None, rel_embeddings=rel_embeddings.weight)
------------------------------------------------------------------------------- output ------------------------------------------------------------
Traceback (most recent call last):
File "/home/deeplearn/pycharm/plugins/python-ce/helpers/pydev/pydevd.py", line 1477, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/run/media/deeplearn/sata_04/manjaro_ai_apps/pycharm-community-2020.3.3/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/run/media/deeplearn/sata_04/LinuxProjects/github_practice/microsoft_deberta/tst/tst_001.py", line 25, in
In our current version, we only support key and query having the same length.
@BigBird01 Hi, thanks for your great contribution. I'd like to know: if lines 158-159 and lines 164-165 are removed, can the position-to-content attention support key and query of different lengths?
In addition to removing the above lines, you need to modify line 152 to:
r_pos = build_relative_position(key_layer.size(-2), query_layer.size(-2), bucket_size = self.position_buckets, max_position = self.max_relative_positions).to(query_layer.device)
and modify line 163 to:
p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), query_layer.size(-2)])).transpose(-1,-2)
Does this fix the issue?
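For anyone following along, here is a toy, self-contained sketch of the gather pattern the modified line 163 relies on. Everything in it (the sizes q_len/k_len/att_span, the random p2c_att, and the simplified stand-in for build_relative_position) is illustrative and not the repository code; the point is only that gathering with a (k_len, q_len) index along the embedding axis and then transposing already yields a (first dim, q_len, k_len) result, which is why the old re-indexing on lines 164-165 can be dropped.

import torch

# Toy sizes; q_len != k_len is the case being discussed.
q_len, k_len, att_span = 6, 4, 8

# Crude stand-in for build_relative_position(k_len, q_len, ...); the sign
# convention and bucketing of the real helper may differ.
rel = torch.arange(k_len)[:, None] - torch.arange(q_len)[None, :]
p2c_pos = torch.clamp(-rel + att_span, 0, att_span * 2 - 1)          # (k_len, q_len)

# Stand-in for the p2c scores before the gather: each key position scored
# against every relative-position embedding. The first dim stands in for
# batch*heads (just 2 here).
p2c_att = torch.randn(2, k_len, att_span * 2)

# Analogue of the modified line 163: gather along the embedding axis with a
# (k_len, q_len) index, then transpose to get (2, q_len, k_len) directly.
p2c_att = torch.gather(
    p2c_att, dim=-1,
    index=p2c_pos.unsqueeze(0).expand(2, k_len, q_len),
).transpose(-1, -2)
print(p2c_att.shape)                                                  # torch.Size([2, 6, 4])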
An error occurs in DisentangledSelfAttention.forward() when query_states.size(1) > hidden_states.size(1):
https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/disentangled_attention.py
line 165: p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
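Since the traceback earlier in the thread is cut off, the exact exception isn't shown; as a general note (an assumption about the failure mode, not a claim about DeBERTa's actual tensor shapes), torch.gather requires index.size(d) <= input.size(d) on every dimension other than the gather dimension, so an index expanded for a longer query against a p2c_att built for equal lengths trips that check. A standalone toy illustration:

import torch

# Toy shapes only: an attention map built as if query_len == key_len == 5,
# indexed as if the query were longer (7). The size check on the last
# (non-gather) dimension fails and gather raises a RuntimeError.
p2c_att = torch.randn(2, 5, 5)
pos_index = torch.zeros(2, 5, 7, dtype=torch.long)

try:
    torch.gather(p2c_att, dim=-2, index=pos_index)
except RuntimeError as err:
    print(err)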