andrewor14 closed this pull request 3 days ago
Note: Links to docs will display an error until the docs builds have been completed.
As of commit 95961d456d9fc5a07dd969be4dbbddd3a86fb1c1 with merge base abdb5a43c1173cdb05208ca6fd498919536c4c19: :green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @andrewor14, I was hacking around with LoRA/QLoRA + INT8 mixed-precision and came across this PR of yours. I realized what we are trying to achieve is quite similar.
Components of `LoRALinear`:

- `F.linear(x, self.weight)` -> we want to modify this op
- `self.lora_b(self.lora_a(x))`
Since the base weight is a direct attribute of `LoRALinear`, and `F.linear()` is hard-coded, it's hard to extend the functionality of `LoRALinear` without re-writing the whole thing. So I had the idea of making the base weight its own `nn.Linear()` module, so that we can freely swap the base linear module to modify its op.
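For illustration, a minimal sketch of what that refactor could look like (simplified and not the actual torchtune implementation; the constructor arguments and `scaling` factor here are illustrative assumptions, while the `base` submodule name follows the linked POC):

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    """Sketch of a LoRA linear layer whose frozen base weight lives in its
    own nn.Linear submodule (self.base) so it can be swapped out later."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float = 1.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # swappable base op
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank

    def forward(self, x):
        # the base op is now whatever module self.base happens to be,
        # instead of a hard-coded F.linear(x, self.weight)
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```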
I have a POC here: https://github.com/pytorch/torchtune/compare/main...gau-nernst:qlora (you can focus on the `torchtune/modules/peft/lora.py` file). With this, we can re-use the linear module swap in torchao, and you don't need a separate `qat_lora_finetune_distributed.py`, since we can add a quantizer to existing recipes (though I understand you might not want this; I didn't carefully check how the QAT recipe script differs from the other training scripts).
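To make the module-swap idea concrete, here is a hedged sketch (not torchao's actual API; `make_int8_linear` is a hypothetical factory) of how a recipe could swap out the base linear once it is its own submodule:

```python
import torch.nn as nn

def swap_base_linears(model: nn.Module, make_linear) -> nn.Module:
    """Replace the base linear inside every LoRALinear (sketched above) with a
    drop-in module built by make_linear, e.g. a quantized or QAT linear.
    make_linear is expected to reuse the original frozen weight."""
    lora_layers = [m for m in model.modules() if isinstance(m, LoRALinear)]
    for lora in lora_layers:
        lora.base = make_linear(lora.base)
    return model

# Hypothetical usage, where make_int8_linear wraps an nn.Linear into some
# int8 (or fake-quantized) linear implementation:
# model = swap_base_linears(model, make_int8_linear)
```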
Hi @gau-nernst, yeah I agree we can make the base weight more flexible; then we won't need to create a new class every time we need to extend LoRA functionality. cc @ebsmothers for your thoughts on extending `LoRALinear` this way: https://github.com/pytorch/torchtune/compare/main...gau-nernst:qlora. For QAT in particular, though, the current flow uses a full module swap (no tensor subclass yet), so we'll need some other way to initialize the base module, like manually setting `self.base`, so it may not be as elegant there. Also, I think the existing `QATLoRALinear` doesn't add much boilerplate code, so it might be OK.
For the separate recipe, I discussed this with @ebsmothers recently and I think it's torchtune's recipe organization philosophy to keep them separate, so QAT functionality won't complicate the original lora recipe.
Attention: Patch coverage is 10.54852% with 424 lines in your changes missing coverage. Please review.
Project coverage is 24.40%. Comparing base (1814feb) to head (2404803). Report is 7 commits behind head on main.
@ebsmothers Any comments? Does this look good to you?
Hey @andrewor14 sorry for the delay and thanks for your patience here. We are doing planning this week so my available bandwidth for reviewing this has taken a hit. I promise to get to it by Friday at the latest
@ebsmothers Regarding:

> key names of the base linear weight will now have an extra module name in between for `LoRALinear`

In my proof-of-concept above, I handle this by adding the following hooks:
```python
# Registered inside LoRALinear.__init__ in the POC, so checkpoints keep the
# original "weight" key even though the parameter now lives at "base.weight".
def load_state_dict_pre_hook(module, state_dict, prefix, local_metadata,
                             strict, missing_keys, unexpected_keys, error_msgs):
    if isinstance(module, LoRALinear):
        state_dict[f"{prefix}base.weight"] = state_dict.pop(f"{prefix}weight")

self.register_load_state_dict_pre_hook(load_state_dict_pre_hook)

def state_dict_post_hook(module, state_dict, prefix, local_metadata):
    if isinstance(module, LoRALinear):
        state_dict[f"{prefix}weight"] = state_dict.pop(f"{prefix}base.weight")

self.register_state_dict_post_hook(state_dict_post_hook)
```
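As a quick sanity check, a hypothetical round trip (assuming the hooks above are registered inside the `LoRALinear.__init__` of the earlier sketch, as in the POC) keeps the key names in the original format:

```python
lora = LoRALinear(in_dim=16, out_dim=16, rank=4)

sd = lora.state_dict()
# the post hook renamed "base.weight" back to "weight", so the checkpoint
# looks the same as before the refactor
assert "weight" in sd and "base.weight" not in sd

# the pre hook renames "weight" to "base.weight" on the way in, so an
# old-format checkpoint loads without key mismatches
lora.load_state_dict(sd)
```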
From my testing it seems sufficient, though I might not have covered all edge cases (FSDP2?).
We can discuss this more in a separate issue/PR if you are open to it, so as not to hijack this PR about QAT + LoRA 😄. The main benefit is the ease of injecting custom logic, such as QAT for this PR, INT8 matmul for #1552, or even FP8 matmul in the future. You probably know better than I do what the potential issues are, but I think we can try to see if they can be handled nicely.
@gau-nernst personally I have a bit of an aversion to state dict hooks, as @pbontrager can attest 😅. Mainly I find that they make code really hard to debug. Correct usage generally requires that a module's state dict is called exactly once and that submodules are not accessed or modified in any other way. If either of these constraints is not satisfied, the user will get a very non-obvious error about some missing attribute, and it won't be at all clear where to go to fix it.
But I agree with your point about consolidating the discussion elsewhere (it sounds like this PR wouldn't benefit as much from modifying `LoRALinear`'s `self.weight` anyway). Maybe a lightweight RFC discussing the pros and cons would be helpful (I can add my comments there as well), and we can tag other folks to get their thoughts too.
@ebsmothers any other comments?
OK a couple more small comments but after that I think this should be good to go. A couple other requests before landing:
- Can you make sure this works with all our usual features (e.g. activation checkpointing, activation offloading)? I already ran with compile myself so no need to worry about that one
- You should also add it to the recipes table in our readme! That way people will know to try it out
Sounds good. I think I addressed all of the comments and also tested it with the features you mentioned. Please take another look, thanks!
Summary:
This commit adds a recipe that combines QAT + LoRA, with the main goal of improving the final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe `qat_lora_finetune_distributed` mirrors the existing `lora_finetune_distributed` recipe, which performs only LoRA, and is analogous to the existing `qat_distributed` recipe, which performs only QAT.
For more context on QAT, please visit https://github.com/pytorch/torchtune/pull/980 and https://pytorch.org/blog/quantization-aware-training/.
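For readers who want the gist without following the links: QAT simulates quantization during training via fake quantization. Below is a generic illustration of symmetric per-tensor fake quantization with a straight-through estimator; it is a sketch of the general technique, not torchao's or this recipe's implementation:

```python
import torch

def fake_quantize(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Quantize-dequantize: values are snapped to an int grid but kept in float,
    # so training "sees" the quantization error it will incur at inference time.
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scale = x.detach().abs().max().clamp(min=1e-8) / qmax
    xq = torch.clamp(torch.round(x / scale), qmin, qmax) * scale
    # straight-through estimator: forward returns xq, backward treats the
    # fake quantization as the identity so gradients still flow
    return x + (xq - x).detach()
```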
Test Plan
Unit tests:
Manual tests:
Results: