flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

[Question]: How to Train a Multi-label Text Classifier? #3255

Closed None-Such closed 1 year ago

None-Such commented 1 year ago

Objective

I am trying to use Flair to replicate the Kaggle Toxic Comment Classification Challenge which seeks to identify and classify toxic online comments.

https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge

Approach

To do this I started with the Flair Tutorial: Train a text classifier

https://flairnlp.github.io/docs/tutorial-training/how-to-train-text-classifier

I made 2 minor changes to accommodate the Kaggle Challenge training data:

ClassificationCorpus     -> allow_examples_without_labels=True
TextClassifier       -> multi_label=True

> see code below

Clarification

To avoid any possible confusion, let me clarify one subtle aspect of Multi-label classification:

The Flair 'Train a text classifier' Tutorial training data (Trec 6) has:

    1 intended class (label) per question

The Kaggle Challenge training data has:

    0 to N intended classes (labels) per comment

So the Kaggle data requires multi_label=True, unlike the TREC_6 data in the Flair Tutorial.

Performance

However, as far as I can tell, 'allow_examples_without_labels=True' does not work . . . as it causes inference to entirely fail for me =(

To work around this, I tagged all unlabelled records as 'benign'
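A minimal sketch of that workaround, assuming the Kaggle train.csv layout (a comment_text column plus six 0/1 label columns) and the FastText-style __label__ prefix used in the training file. The helper name to_fasttext_line is hypothetical and is not part of Flair:

```python
# Hypothetical helper (not part of Flair): turn one Kaggle CSV row into
# a FastText-format line, tagging rows without any label as 'benign'.
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene",
                 "threat", "insult", "identity_hate"]


def to_fasttext_line(row: dict, fallback_label: str = "benign") -> str:
    """Prefix the comment with one __label__ tag per active class."""
    labels = [name for name in LABEL_COLUMNS if str(row.get(name, "0")) == "1"]
    if not labels:
        # 0 intended classes: use the fallback label instead of leaving it empty
        labels = [fallback_label]
    # The format is one document per line, so collapse newlines and extra spaces
    text = " ".join(str(row["comment_text"]).split())
    tags = " ".join(f"__label__{label}" for label in labels)
    return f"{tags} {text}"


# In-memory example rows; real rows could come from csv.DictReader
unlabelled = {"comment_text": "Thanks for\nthe edit", "toxic": "0"}
multi = {"comment_text": "some comment", "toxic": "1", "insult": "1"}
print(to_fasttext_line(unlabelled))
print(to_fasttext_line(multi))
```

Each comment then carries 1 to N labels, so every record has at least one class at training time.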

I proceeded to do multiple runs using different models (using Flair 0.12.2), but I got strange behavior both at training time and inference time which differs based on the training model used:

'distilbert-base-uncased'   -> Accuracy 0.959 - Sort of works, F1 jumps to ~85% but then plateaus. Inference is very weak for anything that is not 'toxic' or 'benign'
'roberta-base'          -> Accuracy 0.897 - F1 jumps to ~85% but then becomes static. Every inferred sentence gets almost exactly the same scores =(
'xlm-roberta-base'      -> Accuracy 0.937 - F1 jumps to ~85% but only detects 3 of 7 classes
'xlnet-base-cased'      -> Unable to load into a 24GB GPU even with a batch size of 2 ! =(
'DocumentRNNEmbeddings'     -> Overall, Flair DocumentRNNEmbeddings seemed to do the best and gave no strange behavior =)

> see F1 Scores below

The Kaggle contest winner had a score of: 0.98856. Interesting that the winner's approach seems to align with Flair Stacked Embeddings: https://www.kaggle.com/competitions/jigsaw-toxic-comment-classification-challenge/discussion/52557

Questions

Based on reviewing related GitHub Issues I have the following questions:

Question # 1 - Is my code below correct?

1a. Is there anything wrong with my simple adaptation of the text classifier tutorial (code below) given I am targeting the Kaggle Toxic Comment multi-label data?

Question # 2 - Does the performance make sense?

2a. Is it possible for distilbert-base-uncased to provide a higher accuracy than roberta-base?
2b. Why would the epoch-level F1 scores of the roberta-base and xlm-roberta-base models seem to get stuck after the 1st epoch?
2c. Why would the roberta-base and xlm-roberta-base models get zero for some classes in the test set (see Model Specific Results at bottom)?

Question # 3 - @alanakbik made the comment 'everything seems to be working . . . with our multi-label datasets', in Issue: https://github.com/flairNLP/flair/issues/678#issuecomment-485863025

3a. Which multi-label dataset was @alanakbik referring to in that issue?
3b. And what settings are used to run it?

Question # 4 - @alanrios2001 made the comment 'training with torch's Adam optimizer, using MADGRAD the f1-score just work's fine..' in Issue: https://github.com/flairNLP/flair/issues/678#issuecomment-1526665624.

4a. How does one set an alternate optimizer when using Flair?

Question # 5 - Guidance for setting the learning_rate?

5a. Is there any model specific guidance on setting the learning_rate when fine-tuning a transformer model?
5b. Any general guidance?

Question # 6 - @helpmefindaname mentioned adjusting the loss_weights in Issue: https://github.com/flairNLP/flair/issues/2869#issuecomment-1191657808

6a. Is this an option worth pursuing?
6b. If so, what are reasonable weights?

    self.model = TextClassifier(document_embeddings,
                                label_type=label_type,
                                label_dictionary=label_dict,
                                multi_label=True,
                                loss_weights={ "label1": 3, "label2": 4, .... })

    self.model.loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=self.model.loss_weights)
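To make the effect of pos_weight concrete, here is a plain-Python sketch of the per-element formula documented for torch.nn.BCEWithLogitsLoss. It only illustrates the arithmetic for a single logit; the function name weighted_bce_with_logits is hypothetical, and this is not Flair's or PyTorch's implementation:

```python
import math


def weighted_bce_with_logits(logit: float, target: float,
                             pos_weight: float = 1.0) -> float:
    """Loss = -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid of the raw score
    return -(pos_weight * target * math.log(p)
             + (1.0 - target) * math.log(1.0 - p))


# An undecided prediction (logit 0 -> probability 0.5) on a positive example
plain = weighted_bce_with_logits(0.0, 1.0)
boosted = weighted_bce_with_logits(0.0, 1.0, pos_weight=3.0)
# The same prediction on a negative example is not affected by pos_weight
negative = weighted_bce_with_logits(0.0, 0.0, pos_weight=3.0)
print(plain, boosted, negative)
```

With pos_weight=3 a missed positive costs three times as much as before, while negatives cost the same, which is why a per-label weight can push the model toward rare classes.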

Code

Adapted from Flair Tutorial "Train a text classifier"


# Flair v0.12.2

from flair.data import Corpus
from flair.datasets import ClassificationCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# 1. get the corpus
# corpus: Corpus = TREC_6()
corpus: Corpus = ClassificationCorpus('/media/ubuntu/Drive/workspace/data/training/input/flair/Kaggle-Toxicity/',
                                      train_file='kaggle-toxic-comment-in-fasttext-format.txt',
                                      label_type='class'
                                      # test_file='test.txt',
                                      # dev_file='dev.txt',
                                      # allow_examples_without_labels=True
                                      )

# 2. what label do we want to predict?
# label_type = 'question_class'
label_type='class'

# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('xlm-roberta-base',    # 'distilbert-base-uncased' = 4  'roberta-base' 'xlm-roberta-base' 'xlnet-base-cased'
                                                    fine_tune=True)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings,
                            label_dictionary=label_dict,
                            label_type=label_type,
                            multi_label=True)

# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)

# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=24,  # distilbert-base-uncased: 60    roberta-base: 24    xlm-roberta-base: 16
                  max_epochs=5)

Corpus Statistics

print(corpus.obtain_statistics())

 {
"TRAIN": {
    "dataset": "TRAIN",
    "total_number_of_documents": 129230,
    "number_of_documents_per_class": {
        "obscene": 6851,
        "insult": 6361,
        "threat": 383,
        "identity_hate": 1131,
        "toxic": 12420,
        "severe_toxic": 1300,
        "benign": 116057
    },
    "number_of_tokens_per_tag": {},
    "number_of_tokens": {
        "total": 10268628,
        "min": 2,
        "max": 2087,
        "avg": 79.46009440532384
    }
},
"TEST": {
    "dataset": "TEST",
    "total_number_of_documents": 15954,
    "number_of_documents_per_class": {
        "obscene": 836,
        "insult": 800,
        "threat": 44,
        "identity_hate": 135,
        "toxic": 1490,
        "severe_toxic": 149,
        "benign": 14368
    },
    "number_of_tokens_per_tag": {},
    "number_of_tokens": {
        "total": 1292236,
        "min": 2,
        "max": 1411,
        "avg": 80.99761815218754
    }
},
"DEV": {
    "dataset": "DEV",
    "total_number_of_documents": 14359,
    "number_of_documents_per_class": {
        "obscene": 762,
        "insult": 716,
        "threat": 51,
        "identity_hate": 139,
        "toxic": 1383,
        "severe_toxic": 146,
        "benign": 12894
    },
    "number_of_tokens_per_tag": {},
    "number_of_tokens": {
        "total": 1146079,
        "min": 2,
        "max": 1425,
        "avg": 79.81607354272582
    }
}
}

Model Specific Results

distilbert-base-uncased

Results:
- F-score (micro) 0.9651
- F-score (macro) 0.8729
- Accuracy 0.959

By class:
               precision    recall  f1-score   support

       benign     0.9808    0.9861    0.9834     14303
        toxic     0.8670    0.8308    0.8485      1554
      obscene     0.9488    0.9595    0.9541       888
       insult     0.9511    0.9396    0.9453       828
identity_hate     0.8973    0.7706    0.8291       170
 severe_toxic     0.8243    0.7625    0.7922       160
       threat     0.7800    0.7358    0.7573        53

roberta-base

Results:
- F-score (micro) 0.8469
- F-score (macro) 0.1352
- Accuracy 0.8977

By class:
               precision    recall  f1-score   support

       benign     0.8977    1.0000    0.9461     14323
        toxic     0.0000    0.0000    0.0000      1527
      obscene     0.0000    0.0000    0.0000       860
       insult     0.0000    0.0000    0.0000       799
 severe_toxic     0.0000    0.0000    0.0000       169
identity_hate     0.0000    0.0000    0.0000       143
       threat     0.0000    0.0000    0.0000        46

    micro avg     0.8977    0.8016    0.8469     17867
    macro avg     0.1282    0.1429    0.1352     17867
 weighted avg     0.7196    0.8016    0.7584     17867
  samples avg     0.8977    0.8977    0.8977     17867  
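The gap between the micro and macro F-scores above follows directly from the per-class table. The sketch below recomputes both from the reported numbers, assuming the usual definitions (macro F1 is the unweighted mean of per-class F1, micro F1 is the harmonic mean of micro precision and micro recall):

```python
# Per-class F1 scores copied from the roberta-base table above
f1_by_class = {
    "benign": 0.9461, "toxic": 0.0, "obscene": 0.0, "insult": 0.0,
    "severe_toxic": 0.0, "identity_hate": 0.0, "threat": 0.0,
}

# Macro F1: every class counts equally, so six zeros dominate
macro_f1 = sum(f1_by_class.values()) / len(f1_by_class)

# Micro F1: pooled over all label decisions, so 'benign' dominates
micro_precision, micro_recall = 0.8977, 0.8016
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall)

print(round(macro_f1, 4), round(micro_f1, 4))
```

Rounded to four places this reproduces the reported 0.1352 and 0.8469: the model predicts only 'benign', which is enough for a high micro score on data this skewed.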

xlm-roberta-base

Results:
- F-score (micro) 0.9482
- F-score (macro) 0.5021
- Accuracy 0.9375

By class:
               precision    recall  f1-score   support

       benign     0.9671    0.9954    0.9810     14368
        toxic     0.9285    0.7148    0.8077      1490
      obscene     0.8086    0.9402    0.8695       836
       insult     0.9525    0.7775    0.8562       800
 severe_toxic     0.0000    0.0000    0.0000       149
identity_hate     0.0000    0.0000    0.0000       135
       threat     0.0000    0.0000    0.0000        44

    micro avg     0.9552    0.9413    0.9482     17822
    macro avg     0.5224    0.4897    0.5021     17822
 weighted avg     0.9380    0.9413    0.9377     17822
  samples avg     0.9576    0.9544    0.9555     17822  

DocumentRNNEmbeddings ([WordEmbeddings('glove'),FlairEmbeddings('news-forward'),FlairEmbeddings('news-backward')], hidden_size=512, reproject_words=True, reproject_words_dimension=256)

   Results:
   - F-score (micro) 0.9655
   - F-score (macro) 0.8539
   - Accuracy 0.9586

   By class:
                  precision    recall  f1-score   support

          benign     0.9732    0.9957    0.9843     14310
           toxic     0.9462    0.7703    0.8492      1554
         obscene     0.9621    0.9291    0.9453       875
          insult     0.9352    0.9387    0.9370       800
    severe_toxic     0.8151    0.7301    0.7702       163
   identity_hate     0.7852    0.7626    0.7737       139
          threat     0.8684    0.6111    0.7174        54

       micro avg     0.9661    0.9649    0.9655     17895
       macro avg     0.8979    0.8197    0.8539     17895
    weighted avg     0.9654    0.9649    0.9642     17895
     samples avg     0.9663    0.9661    0.9660     17895

helpmefindaname commented 1 year ago

Hi @None-Such

  1. The code looks fine.
  2. Since roberta-base is clearly fitting to only one class and not really learning, the results make sense. In general, the no-free-lunch theorem still holds for pretrained models.
  3. a) In that post he mentions a set of internal problems, meaning that the data is not publicly available.
  4. You can set the optimizer with trainer.train(..., optimizer=<optimizer_cls>) or trainer.fine_tune(..., optimizer=<optimizer_cls>) respectively. For MADGRAD you have to find an implementation or code it yourself.
  5. There are plenty of blog posts about this on the internet. The TL;DR is: either grid-search over a set of values or use a learning rate finder.
  6. Looking at how skewed your dataset is, I assume this would be essential to improve the macro-F1 score. You will have to find the values by experiment: start with the inverse proportion of the sample counts, then adjust and see where it goes.
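
As a rough illustration of the starting point suggested in item 6, the sketch below derives inverse-proportion weights from the TRAIN counts in the corpus statistics. The helper name inverse_frequency_weights and the choice to normalise by the largest class are assumptions, not something the thread prescribes:

```python
# TRAIN document counts per class, copied from the corpus statistics
train_counts = {
    "obscene": 6851, "insult": 6361, "threat": 383, "identity_hate": 1131,
    "toxic": 12420, "severe_toxic": 1300, "benign": 116057,
}


def inverse_frequency_weights(counts: dict) -> dict:
    """Weight each label by (largest class count / its own count)."""
    largest = max(counts.values())
    return {label: largest / n for label, n in counts.items()}


loss_weights = inverse_frequency_weights(train_counts)
# Print labels from the heaviest weight to the lightest
for label, weight in sorted(loss_weights.items(), key=lambda kv: -kv[1]):
    print(f"{label:>14}: {weight:.1f}")
```

The resulting dictionary has the same shape as the loss_weights argument in the question. Raw ratios this large (several hundred for 'threat') may be too aggressive in practice, so they are only a first guess to be tuned.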

None-Such commented 1 year ago

@helpmefindaname - Most grateful =)