flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/
Other

Add PEFT training and explicit kwarg passthrough #3480

Closed janpf closed 2 months ago

janpf commented 3 months ago

This PR adds the ability to train models using PEFT (LoRA and QLoRA), as well as cleaner handling of explicit kwargs for the model and the config. For example, it was previously not possible to pass kwargs through to the model without also passing them to the config.
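
To illustrate what explicit passthrough means, here is a minimal sketch that routes config kwargs and model kwargs separately, so that model-only options never leak into the config. The helper `build_backbone` and its `config_kwargs`/`model_kwargs` parameters are hypothetical (this is not Flair's actual code), and the two loaders are stand-ins for `AutoConfig.from_pretrained` and `AutoModel.from_pretrained` from transformers, so the sketch runs without downloading anything.

```python
from typing import Any, Callable, Dict, Optional


def build_backbone(
    model_name: str,
    load_config: Callable[..., Any],
    load_model: Callable[..., Any],
    config_kwargs: Optional[Dict[str, Any]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Hypothetical helper: each kwargs dict is routed to exactly one loader."""
    # Config-only kwargs (e.g. dropout rates) shape the config object.
    config = load_config(model_name, **(config_kwargs or {}))
    # Model-only kwargs (e.g. device_map, load_in_8bit) go to the model loader
    # and never touch the config.
    return load_model(model_name, config=config, **(model_kwargs or {}))


# Stand-ins for AutoConfig.from_pretrained / AutoModel.from_pretrained.
def fake_load_config(name: str, **kwargs: Any) -> Dict[str, Any]:
    return {"name": name, **kwargs}


def fake_load_model(name: str, config: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    return {"name": name, "config": config, "model_kwargs": kwargs}


backbone = build_backbone(
    "uklfr/gottbert-base",
    fake_load_config,
    fake_load_model,
    config_kwargs={"hidden_dropout_prob": 0.2},
    model_kwargs={"device_map": "auto", "load_in_8bit": True},
)
print(backbone["config"])        # {'name': 'uklfr/gottbert-base', 'hidden_dropout_prob': 0.2}
print(backbone["model_kwargs"])  # {'device_map': 'auto', 'load_in_8bit': True}
```

Keeping two separate dicts makes the destination of every option explicit instead of relying on a single `**kwargs` being split implicitly.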

If PEFT is neither installed nor used, no error should be raised.
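
One common way to get this behaviour is to check for the optional package up front and import it lazily, only on the code path that needs it. The sketch below assumes that pattern; `maybe_wrap_with_peft` and `PEFT_AVAILABLE` are hypothetical names, while `peft.get_peft_model(model, peft_config)` is PEFT's own function for wrapping a model.

```python
import importlib.util
from typing import Any, Optional

# Only checks whether peft is installed; nothing is imported at module load time.
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None


def maybe_wrap_with_peft(model: Any, peft_config: Optional[Any] = None) -> Any:
    """Hypothetical helper: apply PEFT only when a config is given."""
    if peft_config is None:
        # PEFT not requested: works whether or not peft is installed.
        return model
    if not PEFT_AVAILABLE:
        raise ImportError(
            "A peft_config was passed, but the 'peft' package is not installed "
            "(pip install peft)."
        )
    from peft import get_peft_model  # lazy import, only on the PEFT code path

    return get_peft_model(model, peft_config)


# Without a peft_config the model is returned unchanged, even if peft is missing.
print(maybe_wrap_with_peft("plain model"))
```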

alanakbik commented 3 months ago

Hello @janpf, could you provide a small test script showing how to train a model (for instance, for NER) using PEFT? That would make it easier to test.

janpf commented 3 months ago

Will do, hopefully this week :)

janpf commented 3 months ago

OK, I have a minimal example, adapted from this tutorial: https://flairnlp.github.io/docs/tutorial-training/how-to-train-text-classifier

requirements.txt:

git+https://github.com/flairNLP/flair.git@refs/pull/3480/merge
bitsandbytes
peft
scipy==1.10.1

The code then looks like this:

from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

corpus: Corpus = TREC_6()
label_type = "question_class"
label_dict = corpus.make_label_dictionary(label_type=label_type)

# this is new: PEFT and bitsandbytes
from peft import LoraConfig, TaskType
import bitsandbytes as bnb

# set the quantization config (bitsandbytes)
bnb_config = {
    "device_map": "auto",
    "load_in_8bit": True,
}
# set lora config (peft)
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
)
document_embeddings = TransformerDocumentEmbeddings(
    "uklfr/gottbert-base",
    fine_tune=True,
    # pass both configs using the newly introduced kwargs
    transformers_model_kwargs=bnb_config,
    peft_config=peft_config,
)

classifier = TextClassifier(
    document_embeddings, label_dictionary=label_dict, label_type=label_type
)
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    learning_rate=5.0e-5,
    mini_batch_size=4,
    # I believe explicitly swapping out the optimizer is recommended
    optimizer=bnb.optim.adamw.AdamW,
    max_epochs=1,
)

The resulting model is quite bad, but all QLoRA hyperparameters were kept at their original values. The logs also show that the model has been correctly quantised (lora.Linear8bitLt) and that the adapters have been inserted (lora_A and lora_B):

2024-07-03 16:45:06,535 Model: "TextClassifier(
  (embeddings): TransformerDocumentEmbeddings(
    (model): PeftModelForFeatureExtraction(
      (base_model): LoraModel(
        (model): RobertaModel(
          (embeddings): RobertaEmbeddings(
            (word_embeddings): Embedding(52010, 768)
            (position_embeddings): Embedding(514, 768, padding_idx=1)
            (token_type_embeddings): Embedding(1, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): RobertaEncoder(
            (layer): ModuleList(
              (0-11): 12 x RobertaLayer(
                (attention): RobertaAttention(
                  (self): RobertaSelfAttention(
                    (query): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (key): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (value): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): RobertaSelfOutput(
                    (dense): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): RobertaIntermediate(
                  (dense): Linear8bitLt(in_features=768, out_features=3072, bias=True)
                  (intermediate_act_fn): GELUActivation()
                )
                (output): RobertaOutput(
                  (dense): Linear8bitLt(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (pooler): RobertaPooler(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (activation): Tanh()
          )
        )
      )
    )
  )
  (decoder): Linear(in_features=768, out_features=6, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)"
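
For context on what the printout shows: each `lora.Linear8bitLt` keeps the quantized base layer frozen and adds a trainable low-rank path, `lora_A` (768 -> 8) followed by `lora_B` (8 -> 768). Below is a minimal pure-Python sketch of that forward pass, y = W0·x + (alpha / r)·B(A·x), using toy 2-dimensional matrices instead of 768-dimensional ones. It assumes PEFT's default initialization, where B starts at zero so the adapter is initially a no-op; the `Identity()` shown for `lora_dropout` in the log is consistent with a dropout of 0, so the sketch omits dropout.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product: one dot product per row of m."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(w0: Matrix, a: Matrix, b: Matrix, x: Vector, alpha: float) -> Vector:
    """y = W0 x + (alpha / r) * B (A x); W0 is frozen, only A and B are trained."""
    r = len(a)                        # rank = number of rows of A (8 in the log)
    base = matvec(w0, x)              # frozen base layer (Linear8bitLt in the log)
    delta = matvec(b, matvec(a, x))   # A projects down to r dims, B projects back up
    scale = alpha / r
    return [h + scale * d for h, d in zip(base, delta)]


# Toy sizes: d_in = d_out = 2, rank r = 1.
w0 = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 1.0]]          # r x d_in
b_init = [[0.0], [0.0]]   # d_out x r, zero-initialized
x = [1.0, 2.0]

print(lora_forward(w0, a, b_init, x, alpha=1.0))        # [1.0, 2.0]: no-op at init
print(lora_forward(w0, a, [[1.0], [2.0]], x, alpha=1.0))  # [4.0, 8.0]
```

Only `a` and `b` would receive gradients during training, which is why the trainable parameter count stays so small relative to the full model.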

The log also reports the parameter counts:

2024-07-03 16:45:06,517 trainable params: 294,912 || all params: 126,279,936 || trainable%: 0.2335
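
The trainable count can be reproduced from the printout alone: LoRA was inserted into `query` and `value` in each of the 12 layers, and each adapter adds 8 × 768 weights for `lora_A` plus 768 × 8 for `lora_B`, both without bias. A small sketch of that arithmetic follows; the helper name is made up for illustration, and the total of 126,279,936 is copied from the log rather than derived. Since the result matches 294,912 exactly, the count evidently covers only the adapter weights, not the classifier's `decoder` layer.

```python
def lora_trainable_params(num_layers: int, modules_per_layer: int,
                          d_in: int, d_out: int, r: int) -> int:
    """Adapter weights only: lora_A is r x d_in, lora_B is d_out x r (no biases)."""
    per_module = r * d_in + d_out * r
    return num_layers * modules_per_layer * per_module


# Values read off the model printout: 12 layers, LoRA on query and value,
# 768 -> 8 -> 768.
trainable = lora_trainable_params(num_layers=12, modules_per_layer=2,
                                  d_in=768, d_out=768, r=8)
all_params = 126_279_936  # copied from the log line, not derived
trainable_pct = 100 * trainable / all_params

print(f"trainable params: {trainable:,} || trainable%: {trainable_pct:.4f}")
# trainable params: 294,912 || trainable%: 0.2335
```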

alanakbik commented 2 months ago

Hi @janpf, this looks good. I tested it with a standard BERT model (for which quantization does not seem to be available), and I get results competitive with full fine-tuning when using a higher learning rate for LoRA (5.0e-4, versus 5.0e-5 in the example above):

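from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from peft import LoraConfig, TaskType

# same corpus and label setup as in the example above
corpus = TREC_6()
label_type = "question_class"
label_dict = corpus.make_label_dictionary(label_type=label_type)

document_embeddings = TransformerDocumentEmbeddings(
    "bert-base-uncased",
    fine_tune=True,
    # set LoRA config
    peft_config=LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
    ),
)

classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    # 10x the learning rate used in the QLoRA example above
    learning_rate=5.0e-4,
    mini_batch_size=4,
    max_epochs=1,
)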

Unfortunately, I don't know what is causing the storage error. This is now affecting all PRs.

alanakbik commented 2 months ago

Thanks again for adding this @janpf! Since the tests are now running through, we can merge!