flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/
Other
13.7k stars 2.08k forks source link

[Question]: I got different results when I loaded a trained model to do the inference task on the same test data #3478

Open xiulinyang opened 3 days ago

xiulinyang commented 3 days ago

Question

Hi,

I'm training a SequenceTagger to do the NER task using my customed dataset with the customed features. After training was done, I got a file named test.tsv which is the prediction of the test split. However, when I loaded the trained model (final-model.pt) and did inference on the same test data, I got way lower results (0.86 vs 0.64 in accuracy).

Here is the prediction function I'm using. I checked the sents list, and all the labels are correctly added to each token. During training, I stacked all the features - should I do the same during prediction? I find the main issue is that the model does not understand the labels added - I have a specific label named __TARGET__ which is the signal for the model to give predictions on specific tokens, but it seems that the model ignores the tag. It would be very appreciated for any suggestions. Thanks!

def predict(self, in_path=None, in_format="conllu", out_format="flair", as_text=False, sent="s"):
        model = self.model
        model.eval()
        tagcol = -1
        poscol = 1
        depcol = 2
        attachcol = 3
        arg1col = 4
        arg2col = 5
        arg3col = 6
        if as_text:
            data = in_path
            #data = (data + "\n").replace("<s>\n", "").replace("</s>\n", "\n").strip()
        else:
            data = io.open(in_path,encoding="utf8").read()
        sents = []
        words = []
        true_tags = []
        true_pos = []
        true_dep = []
        true_attach = []
        true_arg1 = []
        true_arg2 = []
        true_arg3 = []
        data = data.strip() + "\n"  # Ensure final new line for last sentence
        for line in data.split("\n"):
            if len(line.strip())==0 or in_format=="sgml" and "</" + sent + ">" in line:
                if len(words) > 0:
                    if flair_version > 8:
                        tokenizer = False
                    else:
                        tokenizer = lambda x:x.split(" ")
                    sents.append(Sentence(" ".join(words),use_tokenizer=tokenizer))
                    for i, word in enumerate(sents[-1]):
                            word.add_label("upos",true_pos[i])
                            word.add_label('deprel', true_dep[i])
                            word.add_label('attach', true_attach[i])
                            word.add_label('arg1', true_arg1[i])
                            word.add_label('arg2', true_arg2[i])
                            word.add_label('arg3', true_arg3[i])
                    words = []
                    true_pos = []
                    true_dep = []
                    true_attach = []
                    true_arg1 = []
                    true_arg2 = []
                    true_arg3 = []
            else:
                  if "\t" in line:
                      fields = line.split("\t")
                      if "." in fields[0]:
                          continue
                      if "-" in fields[0]:
                          continue
                      words.append(line.split("\t")[0])
                      true_tags.append(line.split("\t")[tagcol])
                      true_pos.append(line.split("\t")[poscol])
                      true_attach.append(line.split("\t")[attachcol])
                      true_dep.append(line.split("\t")[depcol])
                      true_arg1.append(line.split("\t")[arg1col])
                      true_arg2.append(line.split("\t")[arg2col])
                      true_arg3.append(line.split("\t")[arg3col])

        # predict tags and print
        if flair_version > 8:
            model.predict(sents, force_token_predictions=True, return_probabilities_for_all_classes=True)
        else:
            model.predict(sents)#, all_tag_prob=True)

        preds = []
        scores = []
        words = []
        for i, sent in enumerate(sents):
            for tok in sent.tokens:
                if flair_version > 8:
                    pred = tok.labels[6].value if len(tok.labels)>6 else "O"
                    score = tok.labels[6].score if len(tok.labels) > 6 else "1.0"
                else:
                    pred = tok.labels[6].value
                    score = str(tok.labels[6].score)
                preds.append(pred)
                scores.append(score)
                words.append(tok.text)
                tok.clear_embeddings()  # Without this, there will be an OOM issue

        toknum = 0
        output = []
        #out_format="diff"
        for i, sent in enumerate(sents):
            tid=1
            if i>0 and out_format=="conllu":
                output.append("")
            for tok in sent.tokens:
                pred = preds[toknum]
                score = str(scores[toknum])
                if len(score)>5:
                    score = score[:5]
                if out_format == "conllu":
                    pred = pred if not pred == "O" else "_"
                    fields = [str(tid),tok.text,"_",pred,pred,"_","_","_","_","_"]
                    output.append("\t".join(fields))
                    tid+=1
                elif out_format == "xg":
                    output.append("\t".join([pred, tok.text, score]))
                elif out_format == "tt":
                    output.append("\t".join([tok.text, pred]))
                else:
                    true_tag = true_tags[toknum]
                    corr = "T" if true_tag == pred else "F"
                    output.append("\t".join([tok.text, pred, true_tag, corr, score]))
                toknum += 1

        if as_text:
            return "\n".join(output)
        else:
            ext = "gme.conllu" if out_format == "conllu" else "txt"
            partition = "test" if "test" in in_path else "dev"
            with io.open(script_dir + TRAIN_PATH +os.sep + "flair-"+partition+"-pred." + ext,'w',encoding="utf8",newline="\n") as f:
                f.write("\n".join(output))
alanakbik commented 2 days ago

Hello @xiulinyang it is hard to tell without a runnable example to reproduce the error. But I would suspect the error lies somewhere in the way you read the sentences.

Some ideas:

xiulinyang commented 1 day ago

Hi @alanakbik, thanks a lot for your prompt reply! :)

I tried to use ColumnCorpus in Flair (as seen below) but the issue remains. Is it possible for me to send you the code, training log and the data for reproduce the error? Or is there anything that I misunderstood in the code below? Thanks!

   def predict(self, in_path=None, as_text=False):
        model = self.model
        if as_text:
            data = in_path
            #data = (data + "\n").replace("<s>\n", "").replace("</s>\n", "\n").strip()
        else:
            data = io.open(in_path,encoding="utf8").read()
        true_tags = []

        data = data.strip() + "\n"  # Ensure final new line for last sentence
        for line in data.split("\n"):
            if "\t" in line:
                true_tags.append(line.split("\t")[7])

        columns = {0: "text", 1: "upos", 2: "deprel", 3: "attach", 4: "arg1", 5: "arg2", 6: "arg3", 7: "ner"}

        corpus: Corpus = ColumnCorpus(data_folder, columns, train_file="train.biodep.sample.tab",
                                      test_file="test.biodep.sample.tab", dev_file="dev.biodep.sample.tab", )

        output = []
        toknum = 0
        model.predict(corpus.dev, force_token_predictions=True, return_probabilities_for_all_classes=True))

        for sentence in corpus.dev:
            for token in sentence:

                text = token.text
                if len(token.labels) > 6:
                    ner_tag = token.labels[6].value
                    ner_value = token.labels[6].score
                else:
                    ner_tag = "O"
                    ner_value = "1.0"

                corr = 'T' if ner_tag == true_tags[toknum] else 'F'
                output.append(text+'\t'+true_tags[toknum]+'\t'+ner_tag+ '\t'+ corr +'\t' + ner_value)
                toknum +=1

        partition = "test" if "test" in in_path else "dev"
        with io.open(script_dir + TRAIN_PATH +os.sep + "flair-"+partition+"-pred.tsv",'w',encoding="utf8",newline="\n") as f:
            f.write("\n".join(output))