microsoft / DeBERTa

The implementation of DeBERTa

Is this a bug? In disentangled_attention.py, pos_query_layer has dimension 3; when p2p attention is used, the line `pos_query = pos_query_layer[:,:,att_span:,:]` raises `IndexError: too many indices for tensor of dimension 3` #51

Open hj-github1256 opened 3 years ago

hj-github1256 commented 3 years ago

In disentangled_attention.py, `pos_query_layer` is a 3-dimensional tensor, but when p2p attention is selected, the following line raises an IndexError:

`pos_query = pos_query_layer[:,:,att_span:,:]`
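For context, here is a minimal standalone sketch (not the repository code; the shape below is only an assumption) showing why a four-index slice fails on a 3-D tensor and succeeds once a leading dimension is added:

```python
import torch

att_span = 8
# Hypothetical shape (num_heads, 2 * att_span, head_dim); not taken from the repo.
pos_query_layer = torch.randn(8, 2 * att_span, 64)

try:
    # Four indices into a 3-D tensor: this is the failing pattern.
    pos_query = pos_query_layer[:, :, att_span:, :]
except IndexError as e:
    print(e)  # too many indices for tensor of dimension 3

# The same slice works once the tensor has an extra leading dimension.
pos_query = pos_query_layer.unsqueeze(0)[:, :, att_span:, :]
print(pos_query.shape)  # torch.Size([1, 8, 8, 64])
```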

Test code:

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

##########################################
import os
os.chdir(r'F:\WorkSpace\DeBERTa-master')

import numpy as np

import torch
from torch.nn import CrossEntropyLoss
from torch import optim

import math
import pdb

from DeBERTa.deberta import *
from DeBERTa.utils import *

from DeBERTa.deberta.config import ModelConfig
from DeBERTa.apps.models.sequence_classification import SequenceClassificationModel

import os
import time

import warnings
warnings.filterwarnings('ignore')
##########################################

# Custom losses / classifier from my own project files.
from FocalLoss import FocalLoss
from DualFocalLoss import Dual_Focal_loss
from circle_loss import convert_label_to_similarity, CircleLoss
from classify import Classify
import itertools

criterion_circle = CircleLoss(m=0.25, gamma=256)
criterion_focal_loss = FocalLoss(gamma=2.0, alpha=0.25, size_average=True)
criterion_focal_loss = FocalLoss(gamma=2.0, alpha=0.25, size_average=False)
criterion_dual_focal_loss = Dual_Focal_loss(ignore_lb=255, eps=1e-5, reduction='mean')

classifier = Classify(520, 28, 2)
classifier = Classify(8, 16, 2)

#######################################
Data_dim = 32
Data_S_npy_d5 = 138

# Small DeBERTa classification model with relative attention enabled.
config_dict = {
    "hidden_size": 32,
    "num_hidden_layers": 3,
    "num_attention_heads": 8,
    "hidden_act": "gelu",
    "intermediate_size": 128,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "max_position_embeddings": 65,
    "type_vocab_size": 0,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-7,
    "padding_idx": 0,
    "vocab_size": 68,
    "relative_attention": True,
    "max_relative_positions": 11,
    "position_buckets": 8,
    "position_biased_input": True,
    "pos_att_type": "p2c|c2p|p2p"  # why does p2p not work?
}
config_my = ModelConfig.from_dict(config_dict)
DeBERTa_c = SequenceClassificationModel(config_my, num_labels=2, drop_out=None, pre_trained=None)

DeBERTa_c.double()
DeBERTa_c.train()
DeBERTa_c_optim = optim.Adam(DeBERTa_c.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False)
optimizer_chain = optim.Adam(itertools.chain(DeBERTa_c.parameters(), classifier.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False)

device = torch.device('cpu')
batch_x_mark = None
batch_y_mark = None

# One training step on random data; the forward pass raises the IndexError.
train_loss = []
iter_count = 0
iter_count += 1
device = torch.device('cpu')
start_time = time.time()
DeBERTa_c_optim.zero_grad()
batch_x = np.random.randint(1, 66, size=(64, 65), dtype=int)
batch_y = np.random.rand(64, 2)
batch_x = torch.from_numpy(batch_x)
batch_y = torch.from_numpy(batch_y[:, 1]).ge(0.5).long()
batch_x = batch_x.double().to(device)
logits, loss = DeBERTa_c(batch_x, type_ids=None, input_mask=None, labels=batch_y, position_ids=None)
```

I got this:

```
F:\WorkSpace\DeBERTa-master\DeBERTa\deberta\disentangled_attention.py in disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
    169   if 'p2p' in self.pos_att_type:
    170     print("pos_query_layer.shape:", pos_query_layer.shape)
--> 171     pos_query = pos_query_layer[:,:,att_span:,:]
    172     p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
    173     p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])

IndexError: too many indices for tensor of dimension 3
```
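For anyone hitting the same error: the failing slice sits inside the `'p2p' in self.pos_att_type` branch, so dropping `p2p` from `pos_att_type` (e.g. `"p2c|c2p"`) should sidestep it. If p2p attention is actually needed, below is only a sketch of a possible local workaround around lines 169-173, under the assumption that `pos_query_layer` arrives there as a 3-D tensor because `rel_embeddings` carries no batch axis. It avoids the IndexError at the slice, but whether the resulting p2p term and the downstream shapes match what the authors intended would need confirmation from the maintainers.

```python
# Sketch only, not a confirmed fix: guard the slice at line 171 of
# disentangled_attention.py against a 3-D pos_query_layer.
if 'p2p' in self.pos_att_type:
    if pos_query_layer.dim() == 3:
        # Assumed shape (num_heads, 2 * att_span, head_dim): add a leading
        # axis so the original four-index slice applies.
        pos_query = pos_query_layer.unsqueeze(0)[:, :, att_span:, :]
    else:
        pos_query = pos_query_layer[:, :, att_span:, :]
    p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
    # ... remainder of the p2p branch left as in the original code.
```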