pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

NF4 quantization of linear layers without LoRA applied #1119

Closed · winglian closed this 2 months ago

winglian commented 3 months ago

Context

What is the purpose of this PR?

Please link to any issues this PR addresses: #1093

Changelog

Reverts #658 to bring back FrozenNF4Linear. When quantize_base is set to True, the base weights of all linear layers are quantized, even those that do not have LoRA applied to them.
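For intuition, here is a minimal sketch (not torchtune's actual FrozenNF4Linear implementation) of a frozen linear layer whose base weight is stored in NF4 via torchao; the to_nf4 and linear_nf4 names are taken from torchao's nf4tensor module and their exact signatures may vary across versions:

```python
# Minimal sketch, not torchtune's implementation: a linear layer whose frozen
# base weight is NF4-quantized with torchao. to_nf4 / linear_nf4 are assumed
# from torchao.dtypes.nf4tensor and may differ between torchao versions.
import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class FrozenNF4LinearSketch(nn.Module):
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        # Quantize the pretrained weight to NF4 and freeze it: it receives
        # no gradients and therefore adds no optimizer state.
        self.weight = nn.Parameter(to_nf4(weight), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # linear_nf4 handles dequantization for the matmul.
        return linear_nf4(x, self.weight)
```

With this change, linear layers that have no LoRA adapter attached are stored in this kind of frozen NF4 form when quantize_base is True, rather than keeping their full-precision base weights.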

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)

pytorch-bot[bot] commented 3 months ago

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1119

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 44edc45e588b0b6486c1b7728bc9b22dfffabf75 with merge base 58255001bd0b1e3a81a6302201024e472af05379 (image): :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

joecummings commented 3 months ago

cc @msaroufim

winglian commented 3 months ago

Running tune run lora_finetune_single_device --config llama3/8B_qlora_single_device, but with lora_attn_modules: ['q_proj'] only instead of all the attention linear layers, uses 15810 MiB on main and 15000 MiB on this branch.

EDIT: setting lora_attn_modules to cover all 4 attention modules uses 15044 MiB on this branch, which is expected: slightly more than the 15000 MiB above because of the extra optimizer state for the additional LoRA parameters, and still less than main because all the base weights are quantized.
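As a rough illustration of the setup being compared, assuming torchtune's qlora_llama3_8b builder and its lora_attn_modules argument (treat the exact names as assumptions):

```python
# Rough sketch of the comparison, assuming torchtune's qlora_llama3_8b builder
# (a LoRA Llama3-8B builder with quantize_base=True); names may differ by version.
from torchtune.models.llama3 import qlora_llama3_8b

# LoRA on q_proj only. On main, k_proj/v_proj/output_proj kept full-precision
# base weights because they had no LoRA adapters; with this branch their base
# weights are NF4-quantized as well, accounting for the ~800 MiB difference.
model = qlora_llama3_8b(lora_attn_modules=["q_proj"])
```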

winglian commented 2 months ago

Here are the evals. Slight drop from quantization: 0.5682 -> 0.5546.

main

|    Tasks     |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|--------------|------:|------|-----:|------|---|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |↑  |0.5682|±  |0.0149|

nf4 - quantized linear

|    Tasks     |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|--------------|------:|------|-----:|------|---|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |↑  |0.5546|±  |0.0149|
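For reference, the tables above follow EleutherAI lm-eval harness output. A hedged sketch of reproducing a truthfulqa_mc2 score directly with lm-eval, outside torchtune's eval recipe (the checkpoint path is a placeholder, and the simple_evaluate API is assumed from lm-eval >= 0.4):

```python
# Hedged sketch: scoring truthfulqa_mc2 with the EleutherAI lm-eval harness.
# The checkpoint path is a placeholder; simple_evaluate and the result keys
# are assumed from lm-eval >= 0.4 and may differ in other versions.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/merged-checkpoint",
    tasks=["truthfulqa_mc2"],
    num_fewshot=0,
)
print(results["results"]["truthfulqa_mc2"]["acc,none"])
```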

ebsmothers commented 2 months ago

@winglian thanks for adding the eval results; they look reasonable. One remaining thing before we merge: it seems like one of the bnb comparison unit test cases is failing. See here