Closed — janpf closed this PR 2 months ago
Hello @janpf, could you provide a small test script showing how to train a model (for instance for NER) using PEFT? That would make it easier to test.
will do. hopefully this week :)
Ok, I got a minimal example. I adapted this: https://flairnlp.github.io/docs/tutorial-training/how-to-train-text-classifier
requirements.txt:
git+https://github.com/flairNLP/flair.git@refs/pull/3480/merge
bitsandbytes
peft
scipy==1.10.1
The code then looks like this:
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

corpus: Corpus = TREC_6()
label_type = "question_class"
label_dict = corpus.make_label_dictionary(label_type=label_type)

# this is new
from peft import LoraConfig, TaskType
import torch
import bitsandbytes as bnb

# set the quantization config (bitsandbytes)
bnb_config = {
    "device_map": "auto",
    "load_in_8bit": True,
}

# set the LoRA config (peft)
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
)

document_embeddings = TransformerDocumentEmbeddings(
    "uklfr/gottbert-base",
    fine_tune=True,
    # pass both configs using the newly introduced kwargs
    transformers_model_kwargs=bnb_config,
    peft_config=peft_config,
)

classifier = TextClassifier(
    document_embeddings, label_dictionary=label_dict, label_type=label_type
)

trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    learning_rate=5.0e-5,
    mini_batch_size=4,
    # I believe explicitly swapping out the optimizer is recommended
    optimizer=bnb.optim.adamw.AdamW,
    max_epochs=1,
)
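A note on the quantization config: `load_in_8bit=True` hands the actual quantization off to bitsandbytes, which implements the LLM.int8() scheme (vector-wise scaling plus outlier handling). The basic absmax-int8 idea behind it can be sketched in plain Python — this is only an illustration, not what bitsandbytes actually runs:

```python
def quantize_int8(weights):
    """Symmetric absmax quantization: map floats onto int8 [-127, 127]
    using a single scale derived from the largest magnitude."""
    scale = max(abs(w) for w in weights) / 127.0
    return [round(w / scale) for w in weights], scale

def dequantize_int8(q, scale):
    """Recover approximate floats from the int8 values and the scale."""
    return [v * scale for v in q]

weights = [0.12, -0.5, 0.33, 0.9]
q, scale = quantize_int8(weights)   # integers in [-127, 127]
approx = dequantize_int8(q, scale)  # close to the originals, off by at most scale/2
```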
The resulting model is quite bad, but all QLoRA hyperparameters have been kept at the original values. The logs also show that the model has been correctly quantised (lora.Linear8bitLt) and that the adapters have been inserted (lora_A & lora_B):
2024-07-03 16:45:06,535 Model: "TextClassifier(
  (embeddings): TransformerDocumentEmbeddings(
    (model): PeftModelForFeatureExtraction(
      (base_model): LoraModel(
        (model): RobertaModel(
          (embeddings): RobertaEmbeddings(
            (word_embeddings): Embedding(52010, 768)
            (position_embeddings): Embedding(514, 768, padding_idx=1)
            (token_type_embeddings): Embedding(1, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): RobertaEncoder(
            (layer): ModuleList(
              (0-11): 12 x RobertaLayer(
                (attention): RobertaAttention(
                  (self): RobertaSelfAttention(
                    (query): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (key): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (value): lora.Linear8bitLt(
                      (base_layer): Linear8bitLt(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Identity()
                      )
                      (lora_A): ModuleDict(
                        (default): Linear(in_features=768, out_features=8, bias=False)
                      )
                      (lora_B): ModuleDict(
                        (default): Linear(in_features=8, out_features=768, bias=False)
                      )
                      (lora_embedding_A): ParameterDict()
                      (lora_embedding_B): ParameterDict()
                    )
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): RobertaSelfOutput(
                    (dense): Linear8bitLt(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): RobertaIntermediate(
                  (dense): Linear8bitLt(in_features=768, out_features=3072, bias=True)
                  (intermediate_act_fn): GELUActivation()
                )
                (output): RobertaOutput(
                  (dense): Linear8bitLt(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (pooler): RobertaPooler(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (activation): Tanh()
          )
        )
      )
    )
  )
  (decoder): Linear(in_features=768, out_features=6, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)"
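For reference, each lora_A/lora_B pair in the printout implements the low-rank update y = W x + B (A x) on top of the frozen base weight W (here A: 768 → 8 and B: 8 → 768, i.e. rank 8). A tiny pure-Python illustration of that forward pass — not PEFT's actual implementation, and with LoRA dropout omitted:

```python
def matvec(M, x):
    """Multiply a matrix (given as a list of rows) with a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, scaling=1.0):
    """y = W x + scaling * B (A x): the frozen base weight plus the
    trainable low-rank update from the lora_A / lora_B pair."""
    base = matvec(W, x)               # frozen path
    update = matvec(B, matvec(A, x))  # rank-r detour: A projects down, B projects up
    return [b + scaling * u for b, u in zip(base, update)]

# toy dimensions: d=2, rank r=1 (the model above uses d=768, r=8)
W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (identity here)
A = [[1.0, 1.0]]              # lora_A: 2 -> 1
B = [[2.0], [0.5]]            # lora_B: 1 -> 2
y = lora_forward(W, A, B, [1.0, 2.0])  # -> [7.0, 3.5]
```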
and also:
2024-07-03 16:45:06,517 trainable params: 294,912 || all params: 126,279,936 || trainable%: 0.2335
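That trainable% figure is simply trainable parameters over total parameters (294,912 / 126,279,936 ≈ 0.2335 %). With a real model you would sum `p.numel()` over `model.parameters()`; a dependency-free sketch of the same bookkeeping:

```python
def trainable_percentage(param_sizes):
    """param_sizes: iterable of (num_elements, requires_grad) pairs,
    e.g. [(p.numel(), p.requires_grad) for p in model.parameters()]."""
    trainable = sum(n for n, requires_grad in param_sizes if requires_grad)
    total = sum(n for n, _ in param_sizes)
    return trainable, total, 100.0 * trainable / total

# reproduce the logged numbers: the LoRA adapters are the only trainable part
trainable, total, pct = trainable_percentage(
    [(294_912, True), (126_279_936 - 294_912, False)]
)
# trainable=294912, total=126279936, pct rounds to 0.2335
```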
Hi @janpf, this looks good. I tested with a standard BERT model (for which quantization seems not to be available), and I'm getting results competitive with full fine-tuning when setting a slightly higher learning rate for LoRA:
from peft import LoraConfig, TaskType

document_embeddings = TransformerDocumentEmbeddings(
    "bert-base-uncased",
    fine_tune=True,
    # set the LoRA config
    peft_config=LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
    ),
)

classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification-with-transformer",
    learning_rate=5.0e-4,
    mini_batch_size=4,
    max_epochs=1,
)
Unfortunately, I don't know what is causing the storage error. This is now affecting all PRs.
Thanks again for adding this @janpf! Since the tests are now running through, we can merge!
This PR adds the ability to train models using PEFT (LoRA and QLoRA) and improves the handling of explicit kwargs for the model and its config. For example, it was previously not possible to pass kwargs through to the model but not to the config.
If PEFT is neither installed nor used, no error should be thrown.
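The "no error if PEFT is absent" behaviour follows the usual optional-dependency pattern; a minimal sketch of the idea (names here are illustrative, not Flair's actual internals):

```python
try:
    import peft  # optional dependency
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

def maybe_wrap_with_peft(model, peft_config=None):
    """Only touch PEFT when a config is actually passed, so a missing
    install never breaks users who don't ask for it."""
    if peft_config is None:
        return model  # PEFT neither requested nor needed: no error
    if not PEFT_AVAILABLE:
        raise ImportError("a peft_config was passed, but `peft` is not installed")
    return peft.get_peft_model(model, peft_config)
```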