flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

[Question]: I got different results when I loaded a trained model to do the inference task on the same test data #3478

Open xiulinyang opened 5 months ago

xiulinyang commented 5 months ago

Question

Hi,

I'm training a SequenceTagger for NER on my custom dataset with custom features. After training finished, I got a file named test.tsv containing the predictions on the test split. However, when I loaded the trained model (final-model.pt) and ran inference on the same test data, I got much lower results (0.86 vs 0.64 accuracy).

Here is the prediction function I'm using. I checked the sents list, and all the labels are correctly added to each token. During training, I stacked all the features; should I do the same during prediction? The main issue seems to be that the model does not use the added labels: I have a specific label named __TARGET__ which signals the model to make predictions on specific tokens, but the model seems to ignore it. Any suggestions would be much appreciated. Thanks!

```python
import io
import os

from flair.data import Sentence

# flair_version, script_dir and TRAIN_PATH are module-level globals defined elsewhere (not shown)


def predict(self, in_path=None, in_format="conllu", out_format="flair", as_text=False, sent="s"):
    model = self.model
    model.eval()
    tagcol = -1
    poscol = 1
    depcol = 2
    attachcol = 3
    arg1col = 4
    arg2col = 5
    arg3col = 6
    if as_text:
        data = in_path
        #data = (data + "\n").replace("<s>\n", "").replace("</s>\n", "\n").strip()
    else:
        data = io.open(in_path, encoding="utf8").read()
    sents = []
    words = []
    true_tags = []
    true_pos = []
    true_dep = []
    true_attach = []
    true_arg1 = []
    true_arg2 = []
    true_arg3 = []
    data = data.strip() + "\n"  # Ensure final new line for last sentence
    for line in data.split("\n"):
        if len(line.strip()) == 0 or in_format == "sgml" and "</" + sent + ">" in line:
            if len(words) > 0:
                if flair_version > 8:
                    tokenizer = False
                else:
                    tokenizer = lambda x: x.split(" ")
                sents.append(Sentence(" ".join(words), use_tokenizer=tokenizer))
                # Attach the gold feature columns to each token as labels
                for i, word in enumerate(sents[-1]):
                    word.add_label("upos", true_pos[i])
                    word.add_label("deprel", true_dep[i])
                    word.add_label("attach", true_attach[i])
                    word.add_label("arg1", true_arg1[i])
                    word.add_label("arg2", true_arg2[i])
                    word.add_label("arg3", true_arg3[i])
                words = []
                true_pos = []
                true_dep = []
                true_attach = []
                true_arg1 = []
                true_arg2 = []
                true_arg3 = []
        else:
            if "\t" in line:
                fields = line.split("\t")
                # NOTE: column 0 holds the word here, not a CoNLL-U ID, so these checks
                # also skip words containing "." or "-"
                if "." in fields[0]:
                    continue
                if "-" in fields[0]:
                    continue
                words.append(fields[0])
                true_tags.append(fields[tagcol])
                true_pos.append(fields[poscol])
                true_attach.append(fields[attachcol])
                true_dep.append(fields[depcol])
                true_arg1.append(fields[arg1col])
                true_arg2.append(fields[arg2col])
                true_arg3.append(fields[arg3col])

    # predict tags and print
    if flair_version > 8:
        model.predict(sents, force_token_predictions=True, return_probabilities_for_all_classes=True)
    else:
        model.predict(sents)  # , all_tag_prob=True)

    preds = []
    scores = []
    words = []
    for i, sent in enumerate(sents):
        for tok in sent.tokens:
            # Index 6 assumes the prediction is added after the six feature labels above
            if flair_version > 8:
                pred = tok.labels[6].value if len(tok.labels) > 6 else "O"
                score = tok.labels[6].score if len(tok.labels) > 6 else "1.0"
            else:
                pred = tok.labels[6].value
                score = str(tok.labels[6].score)
            preds.append(pred)
            scores.append(score)
            words.append(tok.text)
            tok.clear_embeddings()  # Without this, there will be an OOM issue

    toknum = 0
    output = []
    #out_format="diff"
    for i, sent in enumerate(sents):
        tid = 1
        if i > 0 and out_format == "conllu":
            output.append("")
        for tok in sent.tokens:
            pred = preds[toknum]
            score = str(scores[toknum])
            if len(score) > 5:
                score = score[:5]
            if out_format == "conllu":
                pred = pred if not pred == "O" else "_"
                fields = [str(tid), tok.text, "_", pred, pred, "_", "_", "_", "_", "_"]
                output.append("\t".join(fields))
                tid += 1
            elif out_format == "xg":
                output.append("\t".join([pred, tok.text, score]))
            elif out_format == "tt":
                output.append("\t".join([tok.text, pred]))
            else:
                true_tag = true_tags[toknum]
                corr = "T" if true_tag == pred else "F"
                output.append("\t".join([tok.text, pred, true_tag, corr, score]))
            toknum += 1

    if as_text:
        return "\n".join(output)
    else:
        ext = "gme.conllu" if out_format == "conllu" else "txt"
        partition = "test" if "test" in in_path else "dev"
        with io.open(script_dir + TRAIN_PATH + os.sep + "flair-" + partition + "-pred." + ext, 'w', encoding="utf8", newline="\n") as f:
            f.write("\n".join(output))
```
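One way to rule out the reading step is to test it separately from flair. In the function above, column 0 is used as the word, so the `"." in fields[0]` and `"-" in fields[0]` checks (which look like CoNLL-U ID filters) also drop words such as `.`, `U.S.` or `well-known`, which changes the input seen at prediction time if the training data kept them. A minimal stdlib sketch of a reader that applies those filters only to a separate ID column, if the file has one; `read_columns` and `id_col` are hypothetical names, not part of flair:

```python
from typing import List, Optional

Fields = List[str]  # the tab-separated columns of one token line


def read_columns(text: str, id_col: Optional[int] = None) -> List[List[Fields]]:
    """Split tab-separated column data into sentences, each a list of per-token field lists.

    Blank lines end a sentence and lines without a tab are ignored. Multiword ("1-2")
    and empty-node ("1.1") IDs are skipped only when an explicit `id_col` is given,
    so the word column itself is never filtered.
    """
    sentences: List[List[Fields]] = []
    current: List[Fields] = []
    # A trailing "" acts as a final blank line so the last sentence is flushed
    for line in text.strip().split("\n") + [""]:
        if not line.strip():
            if current:
                sentences.append(current)
                current = []
            continue
        if "\t" not in line:
            continue
        fields = line.split("\t")
        if id_col is not None and ("-" in fields[id_col] or "." in fields[id_col]):
            continue
        current.append(fields)
    return sentences


sample = "The\tDET\tO\nU.S.\tPROPN\tB-LOC\n.\tPUNCT\tO\n\nHi\tINTJ\tO\n"
sents = read_columns(sample)
print([[f[0] for f in s] for s in sents])  # words per sentence; "." and "U.S." are kept
print(sum(len(s) for s in sents))          # token count, to compare with test.tsv
```

Comparing the token count read this way with the number of token rows in the training-time test.tsv is a quick check that both runs see the same tokens.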
alanakbik commented 4 months ago

Hello @xiulinyang, it is hard to tell without a runnable example that reproduces the error, but I would suspect the error lies somewhere in the way you read the sentences.

Some ideas:

xiulinyang commented 4 months ago

Hi @alanakbik, thanks a lot for your prompt reply! :)

I tried using ColumnCorpus in Flair (see below), but the issue remains. Could I send you the code, training log and data to reproduce the error? Or is there something I misunderstood in the code below? Thanks!

```python
import io
import os

from flair.data import Corpus
from flair.datasets import ColumnCorpus

# data_folder, script_dir and TRAIN_PATH are module-level globals defined elsewhere (not shown)


def predict(self, in_path=None, as_text=False):
    model = self.model
    if as_text:
        data = in_path
        #data = (data + "\n").replace("<s>\n", "").replace("</s>\n", "\n").strip()
    else:
        data = io.open(in_path, encoding="utf8").read()
    true_tags = []

    data = data.strip() + "\n"  # Ensure final new line for last sentence
    for line in data.split("\n"):
        if "\t" in line:
            true_tags.append(line.split("\t")[7])

    columns = {0: "text", 1: "upos", 2: "deprel", 3: "attach", 4: "arg1", 5: "arg2", 6: "arg3", 7: "ner"}

    corpus: Corpus = ColumnCorpus(data_folder, columns, train_file="train.biodep.sample.tab",
                                  test_file="test.biodep.sample.tab", dev_file="dev.biodep.sample.tab")

    output = []
    toknum = 0
    # Predictions are made on corpus.dev, so true_tags above must come from the dev file (in_path)
    model.predict(corpus.dev, force_token_predictions=True, return_probabilities_for_all_classes=True)

    for sentence in corpus.dev:
        for token in sentence:
            text = token.text
            if len(token.labels) > 6:
                ner_tag = token.labels[6].value
                ner_value = str(token.labels[6].score)  # score is a float; convert before joining
            else:
                ner_tag = "O"
                ner_value = "1.0"

            corr = 'T' if ner_tag == true_tags[toknum] else 'F'
            output.append(text + '\t' + true_tags[toknum] + '\t' + ner_tag + '\t' + corr + '\t' + ner_value)
            toknum += 1

    partition = "test" if "test" in in_path else "dev"
    with io.open(script_dir + TRAIN_PATH + os.sep + "flair-" + partition + "-pred.tsv", 'w', encoding="utf8", newline="\n") as f:
        f.write("\n".join(output))
```
xiulinyang commented 4 months ago

Hi, just a quick update. I tried another task (UPOS tagging), but the results still look wrong. I created a Google Colab script to replicate the experiment; it takes only about 5 minutes to run. I attached the data (only 100 sentences). Any insight into potential problems would be much appreciated. Thank you! :) tagger_new.zip

helpmefindaname commented 4 months ago

Hi @xiulinyang I am sorry, but that colab script contains too much code that doesn't seem to be related to flair. The concept of a minimal reproducible example is that you use as little code as possible to show that the error exists. I am not fully grasping what the issue is, nor am I convinced that this is a problem related to flair, as there are so many other factors that I cannot judge if done right or not.

xiulinyang commented 4 months ago

Hi @helpmefindaname,

Thank you very much for your reply! Sorry, I had been debugging for a while and the code was a mess. I have now cleaned up the code, and you can run the experiment again with the data from the tagger folder. It contains 500 sentences for training and 100 sentences, taken from those 500, for testing. (I tried to downsample the data further, but with a smaller dataset the model only reaches an accuracy of 0.)

The main problem right now is that the model doesn't give consistent predictions. After training, a test.tsv file is generated automatically, and when I use tagger.evaluate(), another file named prediction.txt is generated too. I also wrote my own .predict() method, which generates flair-test-pred.txt. When I check these three files, they always give me different results. On this sample the difference seems trivial (the accuracies of the three files differ by less than 2%), but when I trained on the whole dataset of ~10k examples, the difference was very large (0.88 vs 0.3).

Thanks!
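Comparing only overall accuracies hides where test.tsv, prediction.txt and flair-test-pred.txt actually diverge. A minimal stdlib sketch for aligning them token by token, assuming one token per line with whitespace-separated columns, the word in the first column, and gold and predicted tags in columns passed in by the caller; the helper names are hypothetical, and the column indices have to be checked against each file (for the "flair" output format of the first predict() function, the predicted tag is column 1 and the gold tag column 2):

```python
from typing import Dict, List, Tuple

Row = Tuple[str, str, str]  # (word, gold tag, predicted tag)


def read_rows(path: str, gold_col: int, pred_col: int) -> List[Row]:
    """Read (word, gold, pred) from a whitespace-separated file; blank lines are skipped."""
    rows: List[Row] = []
    with open(path, encoding="utf8") as f:
        for line in f:
            fields = line.split()
            if fields:
                rows.append((fields[0], fields[gold_col], fields[pred_col]))
    return rows


def accuracy(rows: List[Row]) -> float:
    """Token-level accuracy of the predicted tags against the gold tags."""
    return sum(gold == pred for _, gold, pred in rows) / len(rows) if rows else 0.0


def disagreements(files: Dict[str, List[Row]]) -> List[Tuple[int, str, Dict[str, str]]]:
    """Return (position, word, {file: predicted tag}) wherever the files' predictions differ."""
    counts = {name: len(rows) for name, rows in files.items()}
    if len(set(counts.values())) != 1:
        raise ValueError(f"files have different token counts: {counts}")
    names = list(files)
    diffs = []
    for i in range(counts[names[0]]):
        words = {files[n][i][0] for n in names}
        if len(words) != 1:
            raise ValueError(f"token {i} is misaligned across files: {sorted(words)}")
        preds = {n: files[n][i][2] for n in names}
        if len(set(preds.values())) > 1:
            diffs.append((i, files[names[0]][i][0], preds))
    return diffs


# In-memory rows for illustration; with real files use e.g.
# read_rows("flair-test-pred.txt", gold_col=2, pred_col=1)
a = [("The", "O", "O"), ("U.S.", "B-LOC", "B-LOC"), (".", "O", "O")]
b = [("The", "O", "O"), ("U.S.", "B-LOC", "O"), (".", "O", "O")]
print(accuracy(a), accuracy(b))
print(disagreements({"test.tsv": a, "flair-test-pred.txt": b}))
```

If `disagreements` raises because token counts or words do not line up, the runs are not scoring the same tokens, which would point at the data-reading step rather than at the model itself.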

xiulinyang commented 2 months ago

Hi, could you please offer some help? Thanks!

helpmefindaname commented 2 months ago

hi @xiulinyang, there is still no minimal example to reproduce, so I am not sure what help I can provide. I would still suggest that you take your example and try to strip out everything that is not related to flair or needed to demonstrate this issue.

xiulinyang commented 2 months ago

@helpmefindaname Hi, sorry, I have removed the unrelated code and only kept what is relevant to flair. I hope this time it works. Thanks!