Nealcly / templateNER

Source code for template-based NER
208 stars 39 forks source link

Hard coded numbers in template_entity function of inference.py #21

Open khairunnisaor opened 2 years ago

khairunnisaor commented 2 years ago

Hi,

would you mind explaining some hard-coded numbers in the template_entity function from inference.py?

def template_entity(words, input_TXT, start):
    """Score every candidate span against every entity template; return the best.

    Each string in ``words`` is concatenated with each template sentence
    (e.g. ``"<span> is a person entity ."``).  The source sentence is fed to
    the encoder, each filled template is fed to the decoder, and a candidate's
    score is the product of the softmax probabilities the model assigns to the
    template's tokens.  The (span, template) pair with the highest score wins.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        words: Sequence of candidate span strings.  Presumably they all begin
            at position ``start`` and grow by one word each (the returned
            ``end`` is ``start + index of the winning span``) -- confirm
            against the caller.
        input_TXT: The source sentence as a single string.
        start: Index of the first word of the candidate spans; returned
            unchanged and used as the offset for ``end``.

    Returns:
        ``[start, end, label, score]`` where ``label`` is one of
        ``'LOC'``, ``'PER'``, ``'ORG'``, ``'MISC'``, ``'O'``.
    """
    template_list = [" is a location entity .", " is a person entity .", " is an organization entity .",
                     " is an other entity .", " is not a named entity ."]
    # Maps a template's position in template_list to its entity label.
    entity_dict = {0: 'LOC', 1: 'PER', 2: 'ORG', 3: 'MISC', 4: 'O'}

    # Every "5" in the original code was the number of templates.
    num_templates = len(template_list)
    num_candidates = num_templates * len(words)

    # input text -> one encoder copy per (span, template) candidate
    input_TXT = [input_TXT] * num_candidates
    input_ids = tokenizer(input_TXT, return_tensors='pt')['input_ids']
    model.to(device)

    # Candidates are laid out span-major: candidate k belongs to span
    # k // num_templates and template k % num_templates.
    temp_list = [word + template for word in words for template in template_list]

    output_ids = tokenizer(temp_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
    # Overwrite the first decoder token with id 2.
    # NOTE(review): presumably the model's decoder start token (for BART,
    # id 2 is </s>) -- confirm against the model config.
    output_ids[:, 0] = 2

    # Number of decoder positions that contribute to each candidate's score.
    output_length_list = [0] * num_candidates
    for span_idx in range(len(words)):
        first = span_idx * num_templates
        # Token count of the un-padded first template for this span, minus 4.
        # NOTE(review): the "- 4" presumably drops the leading special token,
        # the trailing special token and the final " entity ." tokens, so that
        # only the span plus the discriminating template words are scored --
        # confirm against the tokenizer in use.
        base_length = tokenizer(temp_list[first], return_tensors='pt', padding=True,
                                truncation=True)['input_ids'].shape[1] - 4
        output_length_list[first:first + num_templates] = [base_length] * num_templates
        # The last template (" is not a named entity .") has one more word
        # than the others, so its scored span is one position longer.
        output_length_list[first + num_templates - 1] += 1

    score = [1] * num_candidates
    with torch.no_grad():
        # The decoder input drops the last two columns of the padded batch;
        # position i of the output is then used to score token i + 1.
        output = model(input_ids=input_ids.to(device),
                       decoder_input_ids=output_ids[:, :output_ids.shape[1] - 2].to(device))[0]
        for i in range(output_ids.shape[1] - 3):
            # Probability distribution over the vocabulary at position i.
            logits = output[:, i, :].softmax(dim=1).to('cpu').numpy()
            for j in range(num_candidates):
                # Positions past a candidate's scored length (template tail
                # and padding) are ignored.
                if i < output_length_list[j]:
                    score[j] = score[j] * logits[j][int(output_ids[j][i + 1])]

    best_score = max(score)
    best_index = score.index(best_score)
    end = start + best_index // num_templates
    # [start_index, end_index, label, score]
    return [start, end, entity_dict[best_index % num_templates], best_score]

I learned from the other open issues that each 5 is the length of template_list, but what do the other numbers mean?

It would be a great help if you could respond to this. Thank you in advance!

0413Lemon commented 2 years ago

Have you solved this problem?