pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Add support for QAT + LoRA #1931

Closed andrewor14 closed 3 days ago

andrewor14 commented 1 month ago

Summary:

This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe qat_lora_finetune_distributed mirrors the existing lora_finetune_distributed recipe, which performs only LoRA, and is analogous to the existing qat_distributed recipe, which performs only QAT.

Helpful code review commands:

diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py
diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml
diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml
diff --color recipes/configs/llama3_2/1B_lora.yaml recipes/configs/llama3_2/1B_qat_lora.yaml
diff --color recipes/configs/llama3_2/3B_lora.yaml recipes/configs/llama3_2/3B_qat_lora.yaml

For more context on QAT, please visit https://github.com/pytorch/torchtune/pull/980 and https://pytorch.org/blog/quantization-aware-training/.
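
As a rough illustration of the "fake quantization" that QAT inserts during training, here is a minimal pure-Python sketch. It assumes symmetric signed int4 with one scale per group of `groupsize` weights; the function name `fake_quantize_int4` is hypothetical, and the actual torchao quantizer used by the recipe may differ in detail.

```python
def fake_quantize_int4(weights, groupsize=32):
    """Quantize then dequantize a flat list of floats, one group at a time.

    Assumption: symmetric signed int4 range [-8, 7], one scale per group.
    """
    qmin, qmax = -8, 7
    out = []
    for start in range(0, len(weights), groupsize):
        group = weights[start:start + groupsize]
        max_abs = max(abs(w) for w in group)
        # Guard against an all-zero group to avoid dividing by zero.
        scale = max_abs / qmax if max_abs > 0 else 1.0
        for w in group:
            q = max(qmin, min(qmax, round(w / scale)))  # integer code
            out.append(q * scale)                       # back to float
    return out


# Toy weights in [-0.5, 0.5]; two groups of 32.
weights = [0.5 * ((i % 7) - 3) / 3 for i in range(64)]
fq = fake_quantize_int4(weights, groupsize=32)
max_err = max(abs(a - b) for a, b in zip(weights, fq))
print(f"max abs rounding error: {max_err:.4f}")
```

During QAT the forward pass sees values like `fq` instead of `weights`, so the model can adapt to the rounding error before the real quantization step is applied.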

Test Plan

Unit tests:

pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py

Manual tests:

export CUDA_VISIBLE_DEVICES=4,5,6,7
export NCCL_SHM_DISABLE=0
LOG_DIR=/home/andrewor/local/logs/tune/qat_lora

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \
    batch_size=4 \
    quantizer.groupsize=32 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics"

tune run quantize --config quantization \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt"] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    tasks=[wikitext] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

Results:

# Baseline (LoRA only, no QAT)

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6284|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5458|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.2694|±  |   N/A|

# LoRA + QAT (new recipe)

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6245|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5416|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.1208|±  |   N/A|
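
To put the two tables side by side, the following sketch computes the relative change in each metric from the numbers reported above. It also checks the relationship byte_perplexity = 2 ** bits_per_byte, which is assumed here to be how the two wikitext metrics relate; the dictionary names are illustrative only.

```python
baseline = {"bits_per_byte": 0.6284, "byte_perplexity": 1.5458, "word_perplexity": 10.2694}
qat_lora = {"bits_per_byte": 0.6245, "byte_perplexity": 1.5416, "word_perplexity": 10.1208}

# Relative reduction of each metric (lower is better for all three).
relative_drop = {
    name: (baseline[name] - qat_lora[name]) / baseline[name] for name in baseline
}
for name, drop in relative_drop.items():
    print(f"{name}: {100 * drop:.2f}% lower with QAT + LoRA")

# Consistency check between the two byte-level metrics (assumed relationship).
consistent = all(
    abs(2 ** run["bits_per_byte"] - run["byte_perplexity"]) < 1e-3
    for run in (baseline, qat_lora)
)
print("byte metrics consistent:", consistent)
```
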

pytorch-bot[bot] commented 1 month ago

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1931

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit 95961d456d9fc5a07dd969be4dbbddd3a86fb1c1 with merge base abdb5a43c1173cdb05208ca6fd498919536c4c19: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

gau-nernst commented 3 weeks ago

Hey @andrewor14, I was hacking around with LoRA/QLoRA + INT8 mixed-precision and came across this PR of yours. I realized what we are trying to achieve is quite similar.

Components of LoRALinear:

Since the base weight is a direct child of LoRALinear and F.linear() is hard-coded, it's hard to extend LoRALinear without rewriting the whole thing. So I had the idea of making the base weight its own nn.Linear() module, so that we can freely swap the base linear module to modify its op.

I have a POC here https://github.com/pytorch/torchtune/compare/main...gau-nernst:qlora (you can focus on the torchtune/modules/peft/lora.py file). With this, we can re-use the linear module-swap in torchao. And you don't need a separate qat_lora_finetune_distributed.py, since we can add a quantizer to the existing recipes (though I understand you might not want this; I didn't carefully check how the QAT recipe script differs from the other training scripts).
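
The swappable-base idea can be sketched without any framework. The classes below are hypothetical stand-ins written over plain Python lists (they are not torchtune's `LoRALinear` or `nn.Linear`); `RoundedLinear` is an invented placeholder for whatever custom base op one might swap in, such as a fake-quantized or INT8 matmul.

```python
class Linear:
    """Minimal dense layer over Python lists: y = W x."""

    def __init__(self, weight):
        self.weight = weight  # list of rows

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]


class RoundedLinear(Linear):
    """Hypothetical custom base op: rounds weights before the matmul."""

    def __call__(self, x):
        rounded = [[round(w, 1) for w in row] for row in self.weight]
        return Linear(rounded)(x)


class LoRALinearSketch:
    """y = base(x) + (alpha / rank) * B(A(x)), with `base` as its own module."""

    def __init__(self, base, lora_a, lora_b, alpha, rank):
        self.base = base  # swappable: any callable layer works here
        self.lora_a = Linear(lora_a)
        self.lora_b = Linear(lora_b)
        self.scaling = alpha / rank

    def __call__(self, x):
        base_out = self.base(x)
        lora_out = self.lora_b(self.lora_a(x))
        return [b + self.scaling * l for b, l in zip(base_out, lora_out)]


layer = LoRALinearSketch(
    base=Linear([[0.12, -0.26], [0.31, 0.44]]),
    lora_a=[[1.0, 0.0]],    # rank 1, shape (1, 2)
    lora_b=[[0.0], [0.0]],  # zero-initialized, shape (2, 1)
    alpha=2.0,
    rank=1,
)
x = [1.0, 2.0]
y_full = layer(x)

# Swap only the base op; the LoRA branch is untouched.
layer.base = RoundedLinear(layer.base.weight)
y_rounded = layer(x)
print(y_full, y_rounded)
```

The point of the design is the single assignment to `layer.base`: the forward logic of the LoRA wrapper does not need to change when the base op does.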

andrewor14 commented 3 weeks ago

Hi @gau-nernst, yeah I agree we can make the base weight more flexible, then we won't need to create a new class every time we need to extend LoRA functionality. cc @ebsmothers to see your thoughts on extending LoRALinear this way: https://github.com/pytorch/torchtune/compare/main...gau-nernst:qlora. For QAT in particular, though, the current flow uses full module swap (no tensor subclass yet), so we'll need some other way to initialize the base module, like manually setting self.base, so it may not be as elegant there. Also I think the existing QATLoRALinear doesn't add much boilerplate code, so it might be OK.

For the separate recipe, I discussed this with @ebsmothers recently and I think it's torchtune's recipe organization philosophy to keep them separate, so QAT functionality won't complicate the original LoRA recipe.

codecov-commenter commented 3 weeks ago

Codecov Report

Attention: Patch coverage is 10.54852% with 424 lines in your changes missing coverage. Please review.

Project coverage is 24.40%. Comparing base (1814feb) to head (2404803). Report is 7 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|--------------------------|--------:|-------|
| recipes/qat_lora_finetune_distributed.py | 0.00% | 311 Missing |
| ...ests/recipes/test_qat_lora_finetune_distributed.py | 32.60% | 62 Missing |
| torchtune/modules/peft/lora.py | 13.04% | 40 Missing |
| tests/torchtune/modules/peft/test_lora.py | 45.45% | 6 Missing |
| torchtune/training/quantization.py | 61.53% | 5 Missing |
Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1931       +/-   ##
===========================================
- Coverage   67.29%   24.40%   -42.89%
===========================================
  Files         318      325        +7
  Lines       17646    18498      +852
===========================================
- Hits        11874     4515     -7359
- Misses       5772    13983     +8211
```
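
The figures in this report can be cross-checked with a little arithmetic. The sketch below uses only the numbers quoted in the report; the derived patch size (covered plus missing lines) is an inference from the stated patch coverage, not something the report states directly.

```python
# Missing lines per file, as listed in the report (truncated path kept as-is).
missing = {
    "recipes/qat_lora_finetune_distributed.py": 311,
    "...ests/recipes/test_qat_lora_finetune_distributed.py": 62,
    "torchtune/modules/peft/lora.py": 40,
    "tests/torchtune/modules/peft/test_lora.py": 6,
    "torchtune/training/quantization.py": 5,
}
total_missing = sum(missing.values())

# Infer the patch size from the stated patch coverage of 10.54852%.
patch_coverage = 10.54852 / 100
patch_lines = round(total_missing / (1 - patch_coverage))
covered_lines = patch_lines - total_missing

# Project coverage is hits / lines for base and head.
base_coverage = 100 * 11874 / 17646
head_coverage = 100 * 4515 / 18498

print(total_missing, patch_lines, covered_lines)
print(f"base {base_coverage:.2f}%, head {head_coverage:.2f}%")
```
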


andrewor14 commented 2 weeks ago

@ebsmothers Any comments? Does this look good to you?

ebsmothers commented 2 weeks ago

Hey @andrewor14 sorry for the delay and thanks for your patience here. We are doing planning this week so my available bandwidth for reviewing this has taken a hit. I promise to get to it by Friday at the latest.

gau-nernst commented 2 weeks ago

@ebsmothers Regarding

> key names of the base linear weight will now have an extra module name in between for LoRALinear

In my proof-of-concept above, I handle this by adding the following hooks

        # On load: checkpoints store the base weight as "<prefix>weight", so move it
        # under the new "base" submodule before the module's own loading logic runs.
        def load_state_dict_pre_hook(
            module, state_dict, prefix, local_metadata,
            strict, missing_keys, unexpected_keys, error_msgs,
        ):
            if isinstance(module, LoRALinear):
                state_dict[f"{prefix}base.weight"] = state_dict.pop(f"{prefix}weight")

        self.register_load_state_dict_pre_hook(load_state_dict_pre_hook)

        # On save: rename the key back so checkpoints keep the original layout.
        def state_dict_post_hook(module, state_dict, prefix, local_metadata):
            if isinstance(module, LoRALinear):
                state_dict[f"{prefix}weight"] = state_dict.pop(f"{prefix}base.weight")

        self.register_state_dict_post_hook(state_dict_post_hook)

From my testing it seems sufficient, though I might not cover all edge cases (FSDP2?)
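
The effect of the two hooks on checkpoint keys can be shown with plain dictionaries. This is a simplified sketch, not the proof-of-concept code: the helper names and the `layers.0.attn.q_proj.` prefix are hypothetical, and the string values stand in for tensors.

```python
def to_checkpoint_keys(state_dict, prefix):
    """Mimic the state-dict post hook: expose `base.weight` as `weight`."""
    out = dict(state_dict)
    out[f"{prefix}weight"] = out.pop(f"{prefix}base.weight")
    return out


def to_module_keys(state_dict, prefix):
    """Mimic the load pre-hook: move `weight` under the `base` submodule."""
    out = dict(state_dict)
    out[f"{prefix}base.weight"] = out.pop(f"{prefix}weight")
    return out


prefix = "layers.0.attn.q_proj."  # hypothetical module path
module_sd = {
    f"{prefix}base.weight": "W",
    f"{prefix}lora_a.weight": "A",
    f"{prefix}lora_b.weight": "B",
}

ckpt_sd = to_checkpoint_keys(module_sd, prefix)  # what gets saved
restored = to_module_keys(ckpt_sd, prefix)       # what the module sees on load
print(sorted(ckpt_sd))
```

Saved checkpoints keep the original `weight` key, so existing checkpoints and key-mapping code would not need to know about the extra `base` module.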

We can discuss more in a separate issue/PR if you are open to it, so as not to hijack this PR about QAT + LoRA 😄. The main benefit is ease of injecting custom logic, such as QAT for this PR, INT8 matmul for #1552, or even FP8 matmul in the future. You probably know better than me what the potential issues are, but I think we can try to see if those can be handled nicely.

ebsmothers commented 2 weeks ago

@gau-nernst personally I have a bit of an aversion to state dict hooks as @pbontrager can attest 😅. Mainly I find that they make code really hard to debug. Correct usage of modules having state dict hooks generally requires that a module has its state dict called exactly once and submodules are not accessed or modified in any other way. And if either of these constraints is not satisfied the user will get a very non-obvious error about some missing attribute and it won't be at all clear where to go to fix it.
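
As one hedged illustration of how a renaming hook can be fragile (not necessarily the specific failure described above), consider an in-place rename that runs twice on the same dictionary. The function name and the `q_proj.` prefix are hypothetical.

```python
def load_pre_hook(state_dict, prefix):
    """In-place rename, similar to what a load pre-hook does to incoming keys."""
    state_dict[f"{prefix}base.weight"] = state_dict.pop(f"{prefix}weight")


prefix = "q_proj."  # hypothetical module path
checkpoint = {"q_proj.weight": "W"}

load_pre_hook(checkpoint, prefix)  # first call: key renamed as intended

try:
    load_pre_hook(checkpoint, prefix)  # same dict reused: old key is gone
    second_call_error = None
except KeyError as exc:
    second_call_error = exc

print(repr(second_call_error))
```

The error names a key rather than the hook that consumed it, which hints at why such failures can be hard to trace back to their cause.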

But I agree with your point about consolidating the discussion elsewhere (sounds like this PR wouldn't benefit as much from modifying LoRALinear's self.weight anyways). Maybe some lightweight RFC discussing pros and cons would be helpful (I can add my comments there as well), and we can tag other folks to get their thoughts too.

andrewor14 commented 1 week ago

@ebsmothers any other comments?

andrewor14 commented 3 days ago

> OK a couple more small comments but after that I think this should be good to go. A couple other requests before landing:
>
> 1. Can you make sure this works with all our usual features (e.g. activation checkpointing, activation offloading)? I already ran with compile myself so no need to worry about that one
> 2. You should also add it to the recipes table in our readme! That way people will know to try it out

Sounds good. I think I addressed all of the comments and also tested it with the features you mentioned. Please take another look, thanks!