khairunnisaor opened this issue 2 years ago (status: open)
Hi,
Would you mind explaining some of the hard-coded numbers in the `template_entity` function from `inference.py`?
```python
def template_entity(words, input_TXT, start):
    # input text -> template
    words_length = len(words)
    words_length_list = [len(i) for i in words]
    input_TXT = [input_TXT]*(5*words_length)
    input_ids = tokenizer(input_TXT, return_tensors='pt')['input_ids']
    model.to(device)
    template_list = [" is a location entity .", " is a person entity .", " is an organization entity .", " is an other entity .", " is not a named entity ."]
    entity_dict = {0: 'LOC', 1: 'PER', 2: 'ORG', 3: 'MISC', 4: 'O'}
    temp_list = []
    for i in range(words_length):
        for j in range(len(template_list)):
            temp_list.append(words[i]+template_list[j])
    output_ids = tokenizer(temp_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
    output_ids[:, 0] = 2
    output_length_list = [0]*5*words_length
    for i in range(len(temp_list)//5):
        base_length = ((tokenizer(temp_list[i * 5], return_tensors='pt', padding=True, truncation=True)['input_ids']).shape)[1] - 4
        output_length_list[i*5:i*5+ 5] = [base_length]*5
        output_length_list[i*5+4] += 1
    score = [1]*5*words_length
    with torch.no_grad():
        output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids[:, :output_ids.shape[1] - 2].to(device))[0]
        for i in range(output_ids.shape[1] - 3):
            # print(input_ids.shape)
            logits = output[:, i, :]
            logits = logits.softmax(dim=1)
            # values, predictions = logits.topk(1,dim = 1)
            logits = logits.to('cpu').numpy()
            # print(output_ids[:, i+1].item())
            for j in range(0, 5*words_length):
                if i < output_length_list[j]:
                    score[j] = score[j] * logits[j][int(output_ids[j][i + 1])]
    end = start+(score.index(max(score))//5)
    # score_list.append(score)
    return [start, end, entity_dict[(score.index(max(score))%5)], max(score)] #[start_index,end_index,label,score]
```
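My best reading of the remaining constants (a guess from the code above, not confirmed by the authors) can be checked with a small pure-Python sketch. It assumes BART's tokenizer wraps each filled template in `<s>` and `</s>` and that every type word (` location`, ` person`, ` organization`, ` other`, ` named`) is a single BPE token. `TEMPLATES`, `scored_lengths` and `scored_tokens` are hypothetical names, and the token lists are hand-written stand-ins, not real tokenizer output. Under those assumptions, `- 4` excludes the leading `<s>` (it is only ever a decoder input, never a prediction target) plus the three trailing tokens ` entity`, ` .` and `</s>` that every template shares, so only the span and the `is a <type>` part are scored. `output_length_list[i*5+4] += 1` covers `is not a named` being one token longer than the other four templates. `output_ids[j][i + 1]` is the usual shift by one, where decoder step `i` predicts the token at position `i + 1`, and `output_ids[:, 0] = 2` appears to overwrite `<s>` (id 0 in BART's vocabulary) with `</s>` (id 2), which as far as I know is BART's `decoder_start_token_id`.

```python
# Hand-written model of the token positions in template_entity.
# ASSUMPTION: each type word is one BPE token; these lists are NOT real
# tokenizer output, just stand-ins that make the arithmetic visible.
BOS, EOS = "<s>", "</s>"

TEMPLATES = [
    [" is", " a", " location", " entity", " ."],
    [" is", " a", " person", " entity", " ."],
    [" is", " an", " organization", " entity", " ."],
    [" is", " an", " other", " entity", " ."],
    [" is", " not", " a", " named", " entity", " ."],
]


def scored_lengths(word_tokens):
    """Mirror output_length_list for one candidate span (5 entries)."""
    # Like temp_list[i * 5]: the length is taken from the first template only.
    # "- 4" drops <s> (never predicted) and " entity", " .", "</s>".
    base_length = len([BOS] + word_tokens + TEMPLATES[0] + [EOS]) - 4
    lengths = [base_length] * 5
    lengths[4] += 1  # "is not a named" is one token longer than the others
    return lengths


def scored_tokens(word_tokens, j):
    """Tokens whose probabilities are multiplied into the score of template j."""
    seq = [BOS] + word_tokens + TEMPLATES[j] + [EOS]
    n = scored_lengths(word_tokens)[j]
    # Decoder step i predicts seq[i + 1] (the output_ids[j][i + 1] lookup).
    return [seq[i + 1] for i in range(n)]


for j in (0, 2, 4):
    print(j, scored_tokens(["New", " York"], j))
# 0 ['New', ' York', ' is', ' a', ' location']
# 2 ['New', ' York', ' is', ' an', ' organization']
# 4 ['New', ' York', ' is', ' not', ' a', ' named']
```

If this reading is right, each template is scored on the span, `is`, the article and the type word(s), while the common ` entity .` tail and `</s>` never influence the ranking.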
I learned from other issues that the `5`s correspond to the length of `template_list`, but what about the other numbers?
It would be a great help if you could respond to this. Thank you in advance!
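The `- 2`, `- 3`, `start` and `// 5` / `% 5` bookkeeping is easier to see with the model call stripped out. The sketch below is only my reading of the code, not the authors' explanation: `pick_span` is a hypothetical helper, and `step_probs` is a made-up table standing in for `logits[j][output_ids[j][i + 1]]`. It suggests that `- 2` and `- 3` are loose bounds (the decoder is fed all but the last two padded positions, and the loop visits every step that could hold a scored token), while the `i < output_length_list[j]` guard does the real per-candidate cut-off. `start + index // 5` then maps the winning candidate back to a span end, assuming `words` lists spans that begin at `start` and grow by one token each.

```python
# Hypothetical stand-in for the scoring loop in template_entity.
# step_probs[k][i] plays the role of logits[k][output_ids[k][i + 1]]: the
# probability of the gold next token at decoder step i for flat candidate
# k = 5 * word_index + template_index.
ENTITY_DICT = {0: "LOC", 1: "PER", 2: "ORG", 3: "MISC", 4: "O"}


def pick_span(step_probs, output_length_list, start):
    score = [1.0] * len(step_probs)
    num_steps = len(step_probs[0])  # role of output_ids.shape[1] - 3
    for i in range(num_steps):
        for k in range(len(step_probs)):
            # Steps past a candidate's scored length (its shared tail and
            # padding) are skipped, so the outer loop bound can be loose.
            if i < output_length_list[k]:
                score[k] *= step_probs[k][i]
    best = score.index(max(score))
    word_index, template_index = divmod(best, 5)  # same as // 5 and % 5
    return [start, start + word_index, ENTITY_DICT[template_index], score[best]]


# Two candidate spans starting at token 3 give 10 flat candidates.
lengths = [2, 2, 2, 2, 3, 3, 3, 3, 3, 4]
probs = [[0.5, 0.5, 0.5, 0.5] for _ in range(10)]
probs[6] = [0.9, 0.9, 0.9, 0.1]  # word 1, template 1; step 3 is past its length
print(pick_span(probs, lengths, start=3))
```

Here candidate 6 (word index 1, template index 1) wins with 0.9 ** 3, so the call returns start 3, end 4 and label `'PER'`; its 0.1 never enters the product. Since the score is a raw product of probabilities, the `O` template always multiplies in one more factor than the other four.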
Have you solved this problem?