huggingface / peft


Add Assertions for `task_type` in `LoraConfig` #2203


d-kleine commented 5 days ago

Feature request

The current implementation of LoraConfig in the PEFT library does not validate the provided task_type. This can lead to silent failures when users accidentally pass a misspelled or unsupported task type. To improve usability and error handling, I propose adding a validation check that ensures the task_type parameter is one of the supported types.

Supported task types include:

- SEQ_CLS
- SEQ_2_SEQ_LM
- CAUSAL_LM
- TOKEN_CLS
- QUESTION_ANS
- FEATURE_EXTRACTION

If an invalid or misspelled task type is provided, the system should raise a clear error message instead of proceeding silently.

Add a validation check in the LoraConfig class to ensure that the provided task_type is one of the supported types. If not, raise a ValueError with a clear message indicating that the task type is invalid and listing the supported types.

Example implementation:

from peft import PeftConfig

class LoraConfig(PeftConfig):
    def __init__(self, task_type: str, **kwargs):
        super().__init__(**kwargs)
        # Reject anything that is not a known task type before it is stored
        valid_task_types = [
            "SEQ_CLS", "SEQ_2_SEQ_LM", "CAUSAL_LM",
            "TOKEN_CLS", "QUESTION_ANS", "FEATURE_EXTRACTION",
        ]
        if task_type not in valid_task_types:
            raise ValueError(
                f"Invalid task_type '{task_type}'. Supported task types are: {valid_task_types}"
            )
        self.task_type = task_type
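
With this check in place, a misspelled task type fails fast instead of silently misconfiguring the model. A hypothetical usage example, assuming the sketch above:

try:
    LoraConfig(task_type="CASUAL_LM")  # typo: "CASUAL" instead of "CAUSAL"
except ValueError as err:
    print(err)
    # -> Invalid task_type 'CASUAL_LM'. Supported task types are: ['SEQ_CLS', ...]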

Motivation

This feature request addresses a usability problem. Currently, if a user provides an invalid or misspelled task_type, no error is raised, and the model may behave unexpectedly without any indication of what went wrong. This can lead to confusion and wasted time during debugging. By adding assertions for valid task types, we can prevent such issues and provide immediate feedback to users, improving the overall developer experience.

For example, if a user mistakenly sets task_type="CASUAL_LM" instead of the correct CAUSAL_LM, no error is raised and the model proceeds with an incorrect configuration. As a result, the model's performance metrics are not computed correctly and are returned as None (i.e., empty).

Your contribution

I am willing to contribute by submitting a Pull Request (PR) with this feature, but I would like confirmation that this feature is wanted before starting.

Test code to reproduce the issue:

from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Label list for a token classification (NER) task
label_list = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    pad_token_id=tokenizer.eos_token_id,
    torch_dtype=torch.bfloat16,
    device_map="auto", 
    num_labels=len(label_list)
)

# Define LoRA configuration
lora_config = LoraConfig(
    task_type="blabla",  # nonsense task type; currently raises no error
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
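
Running this snippet completes without any error or warning, even though "blabla" is not a valid task type.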

BenjaminBossan commented 3 days ago

Thank you for this proposal and for offering to create a PR. I agree that checking the value should be an overall improvement. If you work on a PR, please don't hard code the possible values but take them from the enum:

https://github.com/huggingface/peft/blob/b1fd97dc3e71d91abf9a8605e8be74ea6bb751c6/src/peft/utils/peft_types.py#L70-L89

Also, this applies to all PEFT methods, not just LoRA, so ideally the solution would solve this for all these methods.
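
A minimal sketch of that suggestion (simplified stand-ins, not PEFT's actual classes): putting the check in the shared PeftConfig base class and deriving the allowed values from the TaskType enum means every method's config inherits the validation.

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Union

class TaskType(str, Enum):
    # Mirrors the values in src/peft/utils/peft_types.py linked above
    SEQ_CLS = "SEQ_CLS"
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
    CAUSAL_LM = "CAUSAL_LM"
    TOKEN_CLS = "TOKEN_CLS"
    QUESTION_ANS = "QUESTION_ANS"
    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"

@dataclass
class PeftConfig:
    task_type: Optional[Union[str, TaskType]] = field(default=None)

    def __post_init__(self):
        # Derive the valid values from the enum instead of hard-coding them;
        # since PeftConfig is the base of all method configs, every PEFT
        # method (LoRA, IA3, prefix tuning, ...) gets the check for free.
        if self.task_type is not None and self.task_type not in list(TaskType):
            raise ValueError(
                f"Invalid task_type '{self.task_type}'. "
                f"Supported task types are: {[t.value for t in TaskType]}"
            )

Keeping the check in one place also avoids the hard-coded list drifting out of sync with the enum as new task types are added.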