This PR mainly involves the following aspects:

QLoRA overall logic:
Set the `int4` field in the model's config to enable QLoRA fine-tuning. The rest of the workflow is consistent with basic task fine-tuning.
Modifications to the model structure:
Add a bool field `int4` to the model configuration files in the folder `src/config`; it acts as a switch that controls whether QLoRA is used. Corresponding adjustments are made in the other relevant structures (`Attention`/`SelfAttentionBlock`/`FFNBlock`/`TransformerBlock`/`DenseGatedACT`/`FeedForward`/`Encoder`/`CPMBee`) so that the appropriate modules are loaded based on the `int4` field.
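As a rough illustration of the switch described above, the snippet below sketches how model-building code could branch on an `int4` config field. The `build_linear` helper and the returned class names are illustrative assumptions, not the actual CPM-Bee construction code.

```python
# Hypothetical sketch: branch on the `int4` config switch to pick the
# linear-layer implementation. Names other than `int4` are assumptions.

def build_linear(config: dict) -> str:
    """Pick the linear-layer implementation based on the `int4` switch."""
    if config.get("int4", False):
        return "Linear4bit"   # QLoRA path: 4-bit quantized base weights
    return "Linear"           # default path, same as basic fine-tuning

config = {"dim_model": 4096, "int4": True}
layer_cls = build_linear(config)
```

With `int4` absent or false, construction falls through to the ordinary path, which is why the rest of the fine-tuning flow stays unchanged.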
In `src/cpm_live/layers/feedforward.py`, add class `Linear4bit` as the QLoRA linear layer, class `Params4bit` as the weight container for `Linear4bit`, and class `DistributedParameter4Int8` to meet encapsulation needs.
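To show the core idea behind a 4-bit linear layer, here is a minimal pure-Python sketch (no torch or bitsandbytes): the base weight is stored quantized and dequantized on the fly in the forward pass. The simple per-row absmax scheme and all names here are illustrative assumptions, not the actual `Linear4bit`/`Params4bit` implementation, which uses a more elaborate blockwise scheme.

```python
# Sketch of a dequantize-on-forward 4-bit linear layer (absmax scheme,
# signed 4-bit range clamped to [-7, 7]). Illustrative only.

def quantize_absmax_int4(row):
    """Quantize one weight row to signed 4-bit ints plus a per-row scale."""
    absmax = max(abs(v) for v in row) or 1.0
    scale = absmax / 7.0
    q = [max(-7, min(7, round(v / scale))) for v in row]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float weights from int4 codes and a scale."""
    return [v * scale for v in q]

class Linear4bitSketch:
    def __init__(self, weight):          # weight: list of float rows
        self.qweight = [quantize_absmax_int4(r) for r in weight]

    def forward(self, x):                # x: flat input vector
        out = []
        for q, scale in self.qweight:
            w = dequantize(q, scale)     # dequantize just-in-time
            out.append(sum(wi * xi for wi, xi in zip(w, x)))
        return out

layer = Linear4bitSketch([[0.5, -1.0], [2.0, 0.0]])
y = layer.forward([1.0, 1.0])
```

Only the small LoRA adapter matrices would stay trainable in full precision; the quantized base weight is frozen, which is what makes QLoRA memory-efficient.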
Add scripts/sample code/README:
`src/quantize_state_dict.py` is the code for compressing the initial weights; QLoRA needs to load the compressed dict as the model weights.
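As a hedged sketch of what such a compression step could look like, the snippet below walks a state dict and replaces each weight row with its int4 codes plus a scale, producing a dict the fine-tuning script could load directly. The absmax scheme, example key, and function names are assumptions for illustration, not the actual `quantize_state_dict.py` logic.

```python
# Illustrative sketch: compress a state dict of float weight matrices into
# (int4 codes, scale) pairs, row by row. Not the real implementation.

def quantize_row(row):
    """Per-row absmax quantization to signed 4-bit ints in [-7, 7]."""
    absmax = max(abs(v) for v in row) or 1.0
    scale = absmax / 7.0
    return [max(-7, min(7, round(v / scale))) for v in row], scale

def quantize_state_dict(state_dict):
    """Replace every weight matrix with its quantized representation."""
    return {name: [quantize_row(r) for r in rows]
            for name, rows in state_dict.items()}

sd = {"ffn.w_in.weight": [[7.0, 3.5], [-1.0, 0.5]]}  # hypothetical key
qsd = quantize_state_dict(sd)
```

The fine-tuning script would then dequantize these entries when materializing the frozen base weights on device.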
`src/finetune_cpm_bee_qlora.py` is the fine-tuning sample code.
`src/scripts/finetune_cpm_bee_qlora.sh` is the fine-tuning sample script.
`tutorials/basic_task_finetune/README_qlora.md` is the fine-tuning tutorial for QLoRA.
Other considerations:
The inspect part of the code has been commented out in `src/finetune_cpm_bee_qlora.py`, since `uint8` tensors do not support `std` and `var`.
A bug in `BMTrain.blocklayer`, where `requires_grad` cannot be passed for `uint8`-typed parameters, also needs to be fixed in sync.