@BenjaminBossan What do you think of this lean implementation?
@d-kleine Thanks for the PR. Right now, I'm at a company offsite and don't have the means to do proper reviews. I'll probably review at the start of next week.
I see what you mean. I have implemented the changes, and they work well so far: an error is raised when setting up the config if either no `task_type` is passed (`None`) or the one passed does not match the task types defined in the mappings.
About the `super().__post_init__()`: I agree this is also a good idea for this project in general. But I have noticed that most parents (e.g. `PeftConfig`, `PromptLearningConfig`) don't have a `__post_init__()` method yet, so calling `super().__post_init__()` inside the child's `def __post_init__(self):` would result in an error:
parent:
```python
@dataclass
class PeftConfig(PeftConfigMixin):
    ...  # no `def __post_init__(self):` here
```
child:
```python
@dataclass
class LoraConfig(PeftConfig):
    ...

    def __post_init__(self):
        super().__post_init__()  # error as not defined in parent
```
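A minimal, self-contained reproduction of that failure, using hypothetical stand-in classes (`ParentConfig`, `ChildConfig`, `try_build_child`) rather than the real PEFT configs: when no class in the MRO defines `__post_init__`, the attribute lookup on the `super()` proxy raises `AttributeError`.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for PeftConfig / LoraConfig: the parent does NOT
# define __post_init__, mirroring the situation described above.
@dataclass
class ParentConfig:
    r: int = 8


@dataclass
class ChildConfig(ParentConfig):
    lora_alpha: int = 16

    def __post_init__(self):
        # No class further up the MRO (including object) defines __post_init__,
        # so this lookup on the super() proxy fails with AttributeError.
        super().__post_init__()


def try_build_child():
    """Return the exception raised while constructing ChildConfig, or None."""
    try:
        ChildConfig()
    except AttributeError as exc:
        return exc
    return None


error = try_build_child()
print(type(error).__name__)  # AttributeError
```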
So, what to do? Something like this? ⬇️
```python
def __post_init__(self):
    if hasattr(super(), '__post_init__'):
        super().__post_init__()
```
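A sketch of how such a guarded call behaves, again with hypothetical stand-ins (`ParentWithHook`, `ParentWithoutHook` and the helper `make_child` are illustrative names, not PEFT code): the parent's hook runs when it exists and is silently skipped otherwise.

```python
from dataclasses import dataclass


# Hypothetical parent that defines the hook.
@dataclass
class ParentWithHook:
    r: int = 8

    def __post_init__(self):
        self.parent_hook_ran = True


# Hypothetical parent that does not define the hook.
@dataclass
class ParentWithoutHook:
    r: int = 8


def make_child(parent):
    """Build a child dataclass of `parent` that chains up only if a hook exists."""

    @dataclass
    class Child(parent):
        lora_alpha: int = 16

        def __post_init__(self):
            # Guarded call: only chain up if some ancestor defines the hook.
            if hasattr(super(), "__post_init__"):
                super().__post_init__()

    return Child


with_hook = make_child(ParentWithHook)()
without_hook = make_child(ParentWithoutHook)()
print(getattr(with_hook, "parent_hook_ran", False))     # True
print(getattr(without_hook, "parent_hook_ran", False))  # False
```

One trade-off of the guard is that it hides a missing hook without any signal; defining the hook once in a shared base class makes every plain `super().__post_init__()` call safe without the check.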
Hmm, I think this check for the presence of `__post_init__` should not be necessary. I tried this locally:
```python
# peft/config.py in PeftConfigMixin
def __post_init__(self):
    if (self.task_type is not None) and (self.task_type not in list(TaskType)):
        raise ValueError(f"Invalid task type: '{self.task_type}'. Must be one of the following task types: {', '.join(TaskType)}.")

# peft/tuners/lora/config.py
def __post_init__(self):
    super().__post_init__()
    ...  # rest of the code
```
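The same pattern can be sketched end to end without installing PEFT. `TaskType` below is a stand-in that copies the members listed in the error message further down, and `ConfigMixin` / `LoraLikeConfig` are hypothetical substitutes for `PeftConfigMixin` / `LoraConfig`; only the shape of the check is taken from the snippet above.

```python
import enum
from dataclasses import dataclass
from typing import Optional, Union


class TaskType(str, enum.Enum):
    # Stand-in enum with the task types named in the error message.
    SEQ_CLS = "SEQ_CLS"
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
    CAUSAL_LM = "CAUSAL_LM"
    TOKEN_CLS = "TOKEN_CLS"
    QUESTION_ANS = "QUESTION_ANS"
    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"


@dataclass
class ConfigMixin:
    # Hypothetical stand-in for PeftConfigMixin: the shared base owns the check.
    task_type: Optional[Union[str, TaskType]] = None

    def __post_init__(self):
        # None stays allowed; any other value must match a TaskType member.
        if (self.task_type is not None) and (self.task_type not in list(TaskType)):
            raise ValueError(
                f"Invalid task type: '{self.task_type}'. "
                f"Must be one of the following task types: {', '.join(TaskType)}."
            )


@dataclass
class LoraLikeConfig(ConfigMixin):
    # Hypothetical stand-in for LoraConfig.
    r: int = 8

    def __post_init__(self):
        super().__post_init__()  # run the shared task_type check first
        # ... method-specific validation would follow here


accepted = [LoraLikeConfig(task_type=None), LoraLikeConfig(task_type="CAUSAL_LM")]

try:
    LoraLikeConfig(task_type="CAUSAL_LMX")
except ValueError as exc:
    print(exc)
```

Because the base class always defines `__post_init__`, every subclass can call `super().__post_init__()` unconditionally.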
Then running:
```python
from peft import LoraConfig

lora_config_0 = LoraConfig(task_type=None)
lora_config_1 = LoraConfig(task_type="CAUSAL_LM")
lora_config_2 = LoraConfig(task_type="CAUSAL_LMX")
```
I get the expected error on the last config:
```
ValueError: Invalid task type: 'CAUSAL_LMX'. Must be one of the following task types: SEQ_CLS, SEQ_2_SEQ_LM, CAUSAL_LM, TOKEN_CLS, QUESTION_ANS, FEATURE_EXTRACTION.
```
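Why the check works with plain strings: the joined member list in this message implies that `TaskType` mixes in `str`, since `str.join` only accepts `str` instances. Members of such an enum compare equal to their string values, so `"CAUSAL_LM" in list(TaskType)` is true, while a plain `Enum` would reject every string. A small comparison, using two hypothetical enums:

```python
import enum


class StrTaskType(str, enum.Enum):
    # Hypothetical enum WITH the str mixin: members are also str instances.
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"


class PlainTaskType(enum.Enum):
    # Hypothetical enum WITHOUT the str mixin, for comparison.
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"


# str mixin: members compare equal to their string values.
str_hit = "CAUSAL_LM" in list(StrTaskType)

# Plain Enum: members never compare equal to strings.
plain_hit = "CAUSAL_LM" in list(PlainTaskType)

# str.join accepts the str-mixin members and uses their underlying values.
joined = ", ".join(StrTaskType)

# Joining plain Enum members fails, because they are not str instances.
try:
    ", ".join(PlainTaskType)
    plain_join_ok = True
except TypeError:
    plain_join_ok = False

print(str_hit, plain_hit, joined, plain_join_ok)  # True False CAUSAL_LM, SEQ_CLS False
```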
Yeah, I mean when you add a `__post_init__` to the parent, then of course the `super().__post_init__()` in the child works. I was trying to say that there is almost no `__post_init__` in the parents yet, but now I understand that you want me to create a `__post_init__` containing the task type check in `PeftConfigMixin` (which all PEFT configs seem to be based on). Is that right?
Yes, exactly, sorry for the confusion.
Alright, thank you!
Thanks to your sample code, I now understand that `task_type` can be `None` as a valid value. Initially, I thought setting `task_type` to `None` would be an invalid configuration and should therefore also raise a `ValueError`. I have not worked without defining a `task_type` yet, so are you sure it is a good idea not to raise an error (or at least a warning) in this case?
I have just implemented the logic you suggested and pushed the changes.
I have also noticed that you can import `TaskType` directly, like this:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType

# Step 1: Load the base model and tokenizer
model_name_or_path = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Step 2: Define the LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type=TaskType.BLABLA,  # Invalid task type
)
```
This would raise an `AttributeError`:
```
AttributeError: BLABLA
```
Even though this error is not as specific as the `ValueError` above, I think this is fine. What do you think?
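The two enum failure modes differ in where they happen: a misspelled enum attribute fails at the call site with `AttributeError`, while looking a member up by value raises `ValueError`. A quick check with a hypothetical two-member stand-in enum (not PEFT's actual `TaskType`) and an illustrative helper `classify`:

```python
import enum


class TaskType(str, enum.Enum):
    # Hypothetical two-member stand-in; PEFT's real enum has more members.
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"


def classify(lookup):
    """Run a lookup and report which exception type (if any) it raised."""
    try:
        lookup()
    except AttributeError:
        return "AttributeError"
    except ValueError:
        return "ValueError"
    return "ok"


attr_typo = classify(lambda: TaskType.BLABLA)          # attribute access on the enum class
value_typo = classify(lambda: TaskType("BLABLA"))      # lookup by value
valid_value = classify(lambda: TaskType("CAUSAL_LM"))  # valid lookup by value

print(attr_typo, value_typo, valid_value)  # AttributeError ValueError ok
```

Either way an enum typo surfaces before any config is built; plain strings are the path that relies on the new `__post_init__` check.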
It should now raise an error if I'm not mistaken. So let's pass a valid task type here. Next, let's create a similar test with an invalid task type and check that the error you added is raised.
I have added the tests (one for valid task types, one for invalid task types), as requested. I have defined the valid task types as the ones defined in `TaskType` plus `None`. For the invalid task types, I have only used the provided example `"test"`.
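A dependency-free sketch of the same two tests, assuming hypothetical stand-in configs (`ConfigA`, `ConfigB`) and a reduced `TaskType`. The real suite uses pytest parametrization over every config class; this version uses `unittest`'s `subTest` instead, which pytest can also run:

```python
import enum
import unittest
from dataclasses import dataclass
from typing import Optional


class TaskType(str, enum.Enum):
    # Hypothetical stand-in with a subset of task types.
    SEQ_CLS = "SEQ_CLS"
    CAUSAL_LM = "CAUSAL_LM"


@dataclass
class BaseConfig:
    task_type: Optional[str] = None

    def __post_init__(self):
        if self.task_type is not None and self.task_type not in list(TaskType):
            raise ValueError(f"Invalid task type: '{self.task_type}'.")


@dataclass
class ConfigA(BaseConfig):  # hypothetical stand-in for e.g. LoraConfig
    pass


@dataclass
class ConfigB(BaseConfig):  # hypothetical stand-in for e.g. IA3Config
    pass


ALL_CONFIG_CLASSES = [ConfigA, ConfigB]
VALID_TASK_TYPES = list(TaskType) + [None]  # every enum member plus None
INVALID_TASK_TYPES = ["test"]


class TestTaskTypeValidation(unittest.TestCase):
    def test_valid_task_type(self):
        for config_class in ALL_CONFIG_CLASSES:
            for task_type in VALID_TASK_TYPES:
                with self.subTest(config_class=config_class.__name__, task_type=task_type):
                    config_class(task_type=task_type)  # must not raise

    def test_invalid_task_type(self):
        for config_class in ALL_CONFIG_CLASSES:
            for task_type in INVALID_TASK_TYPES:
                with self.subTest(config_class=config_class.__name__, task_type=task_type):
                    with self.assertRaises(ValueError):
                        config_class(task_type=task_type)
```

Valid cases are every enum member plus `None`, and the only invalid case is the string `"test"`, matching the test design described above.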
Furthermore, I ran the tests beforehand (you were right, they would have failed). After updating the tests, the unit tests pass:
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-AdaLoraConfig] PASSED [ 0%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-AdaptionPromptConfig] PASSED [ 1%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-BOFTConfig] PASSED [ 2%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-FourierFTConfig] PASSED [ 2%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-HRAConfig] PASSED [ 3%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-IA3Config] PASSED [ 4%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LNTuningConfig] PASSED [ 5%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoHaConfig] PASSED [ 5%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoKrConfig] PASSED [ 6%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-LoraConfig] PASSED [ 7%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-MultitaskPromptTuningConfig] PASSED [ 8%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PolyConfig] PASSED [ 8%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PrefixTuningConfig] PASSED [ 9%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PromptEncoderConfig] PASSED [ 10%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-PromptTuningConfig] PASSED [ 11%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-VeraConfig] PASSED [ 11%]
peft\tests\test_config.py::TestPeftConfig::test_invalid_task_type[test-VBLoRAConfig] PASSED [ 12%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-AdaLoraConfig] PASSED [ 13%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-AdaptionPromptConfig] PASSED [ 13%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-BOFTConfig] PASSED [ 14%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-FourierFTConfig] PASSED [ 15%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-HRAConfig] PASSED [ 16%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-IA3Config] PASSED [ 16%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LNTuningConfig] PASSED [ 17%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoHaConfig] PASSED [ 18%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoKrConfig] PASSED [ 19%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-LoraConfig] PASSED [ 19%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-MultitaskPromptTuningConfig] PASSED [ 20%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PolyConfig] PASSED [ 21%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PrefixTuningConfig] PASSED [ 22%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PromptEncoderConfig] PASSED [ 22%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-PromptTuningConfig] PASSED [ 23%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-VeraConfig] PASSED [ 24%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_CLS-VBLoRAConfig] PASSED [ 25%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-AdaLoraConfig] PASSED [ 25%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-AdaptionPromptConfig] PASSED [ 26%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-BOFTConfig] PASSED [ 27%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-FourierFTConfig] PASSED [ 27%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-HRAConfig] PASSED [ 28%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-IA3Config] PASSED [ 29%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LNTuningConfig] PASSED [ 30%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoHaConfig] PASSED [ 30%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoKrConfig] PASSED [ 31%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-LoraConfig] PASSED [ 32%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-MultitaskPromptTuningConfig] PASSED [ 33%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PolyConfig] PASSED [ 33%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PrefixTuningConfig] PASSED [ 34%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PromptEncoderConfig] PASSED [ 35%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-PromptTuningConfig] PASSED [ 36%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-VeraConfig] PASSED [ 36%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[SEQ_2_SEQ_LM-VBLoRAConfig] PASSED [ 37%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-AdaLoraConfig] PASSED [ 38%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-AdaptionPromptConfig] PASSED [ 38%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-BOFTConfig] PASSED [ 39%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-FourierFTConfig] PASSED [ 40%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-HRAConfig] PASSED [ 41%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-IA3Config] PASSED [ 41%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LNTuningConfig] PASSED [ 42%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoHaConfig] PASSED [ 43%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoKrConfig] PASSED [ 44%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-LoraConfig] PASSED [ 44%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-MultitaskPromptTuningConfig] PASSED [ 45%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PolyConfig] PASSED [ 46%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PrefixTuningConfig] PASSED [ 47%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PromptEncoderConfig] PASSED [ 47%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-PromptTuningConfig] PASSED [ 48%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-VeraConfig] PASSED [ 49%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[CAUSAL_LM-VBLoRAConfig] PASSED [ 50%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-AdaLoraConfig] PASSED [ 50%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-AdaptionPromptConfig] PASSED [ 51%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-BOFTConfig] PASSED [ 52%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-FourierFTConfig] PASSED [ 52%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-HRAConfig] PASSED [ 53%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-IA3Config] PASSED [ 54%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LNTuningConfig] PASSED [ 55%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoHaConfig] PASSED [ 55%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoKrConfig] PASSED [ 56%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-LoraConfig] PASSED [ 57%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-MultitaskPromptTuningConfig] PASSED [ 58%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PolyConfig] PASSED [ 58%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PrefixTuningConfig] PASSED [ 59%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PromptEncoderConfig] PASSED [ 60%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-PromptTuningConfig] PASSED [ 61%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-VeraConfig] PASSED [ 61%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[TOKEN_CLS-VBLoRAConfig] PASSED [ 62%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-AdaLoraConfig] PASSED [ 63%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-AdaptionPromptConfig] PASSED [ 63%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-BOFTConfig] PASSED [ 64%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-FourierFTConfig] PASSED [ 65%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-HRAConfig] PASSED [ 66%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-IA3Config] PASSED [ 66%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LNTuningConfig] PASSED [ 67%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoHaConfig] PASSED [ 68%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoKrConfig] PASSED [ 69%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-LoraConfig] PASSED [ 69%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-MultitaskPromptTuningConfig] PASSED [ 70%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PolyConfig] PASSED [ 71%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PrefixTuningConfig] PASSED [ 72%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PromptEncoderConfig] PASSED [ 72%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-PromptTuningConfig] PASSED [ 73%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-VeraConfig] PASSED [ 74%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[QUESTION_ANS-VBLoRAConfig] PASSED [ 75%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-AdaLoraConfig] PASSED [ 75%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-AdaptionPromptConfig] PASSED [ 76%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-BOFTConfig] PASSED [ 77%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-FourierFTConfig] PASSED [ 77%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-HRAConfig] PASSED [ 78%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-IA3Config] PASSED [ 79%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LNTuningConfig] PASSED [ 80%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoHaConfig] PASSED [ 80%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoKrConfig] PASSED [ 81%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-LoraConfig] PASSED [ 82%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-MultitaskPromptTuningConfig] PASSED [ 83%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PolyConfig] PASSED [ 83%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PrefixTuningConfig] PASSED [ 84%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PromptEncoderConfig] PASSED [ 85%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-PromptTuningConfig] PASSED [ 86%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-VeraConfig] PASSED [ 86%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[FEATURE_EXTRACTION-VBLoRAConfig] PASSED [ 87%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-AdaLoraConfig] PASSED [ 88%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-AdaptionPromptConfig] PASSED [ 88%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-BOFTConfig] PASSED [ 89%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-FourierFTConfig] PASSED [ 90%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-HRAConfig] PASSED [ 91%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-IA3Config] PASSED [ 91%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LNTuningConfig] PASSED [ 92%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoHaConfig] PASSED [ 93%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoKrConfig] PASSED [ 94%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-LoraConfig] PASSED [ 94%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-MultitaskPromptTuningConfig] PASSED [ 95%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PolyConfig] PASSED [ 96%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PrefixTuningConfig] PASSED [ 97%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PromptEncoderConfig] PASSED [ 97%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-PromptTuningConfig] PASSED [ 98%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-VeraConfig] PASSED [ 99%]
peft\tests\test_config.py::TestPeftConfig::test_valid_task_type[None-VBLoRAConfig] PASSED [100%]
All tests in the script also pass.
If you need further changes / improvements, please let me know.
Yes, that's the nice thing about using enums: you immediately know if you made a typo, and your text editor will help you with auto-complete. I'd say this is the "proper" way of doing it, but it requires a bit more typing, so users are often lazy and just enter the string (I'm guilty as well).
Yeah, I agree. I am using VS Code, and importing `TaskType` directly is indeed better.
Great that the enum works as intended. Still, having this extra check for strings will be helpful overall.
Yeah, I fully agree 🙂
@d-kleine Thanks for the update, please run `make style` too.
Done
@d-kleine A test is now failing, which was indirectly caused by the changes in this PR. Long story short, please go to this line and change it to `AdaLoraConfig(init_lora_weights="loftq", loftq_config={"loftq": "config"})`. LMK if you want me to explain why it's needed, but it's really just tangential to your PR.
Fixed and pushed. I was looking into the error too (it seems that the quantized state and its adapter initialization must be aligned for `AdaLoraConfig`). But did this happen because of my PR?
Thanks for the merge!
Just as a side note, I’d like to suggest that the code might benefit from more detailed documentation. While working on the PR, I noticed that additional comments and more comprehensive docstrings would make the code easier to navigate, especially for those who don't maintain the project regularly. This could help make the project more accessible to both users and contributors.
Fixes #2203
This PR adds validation for the `task_type` parameter across all PEFT types (e.g., LoRA, Prefix Tuning, etc.) to ensure that only valid task types, as defined in the `TaskType` enum, are used. This resolves the issue where invalid or nonsensical task types (e.g., `"BLABLA"`) or typos (e.g., `"CUASAL_LM"` instead of `"CAUSAL_LM"`) could be passed without raising an error, potentially causing unexpected behavior.

Why This Change Is Necessary:
`task_type`, maintaining consistency across different PEFT configurations.

Tests:
Example Scenario:
Before this PR, the following code would not raise an error despite using an invalid task type:
With this PR, attempting to use an invalid task type like `"BLABLA"` will now raise a `ValueError`, ensuring that only valid task types are accepted:

For multiple PEFT methods:
(Screenshots: INVALID `task_type` / VALID `task_type`)