OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). #64

Closed zkf331 closed 4 months ago

zkf331 commented 4 months ago

I modified the code to support the Codellama-34b model, but when using LWC and LET simultaneously, the following error occurred:

```
Traceback (most recent call last):
  File "main.py", line 380, in <module>
    main()
  File "main.py", line 345, in main
    omniquant(
  File "~/OmniQuant/OmniQuant-main/quantize/omniquant.py", line 358, in omniquant
    norm = loss_scaler(loss, optimizer,
  File "~/OmniQuant/OmniQuant-main/utils.py", line 34, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph)
  File "~/aconda/envs/omiquant/lib/python3.8/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "~/aconda/envs/omiquant/lib/python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```

What could be the problem? Is there any good solution?

ChenMnZ commented 4 months ago

Please refer to https://github.com/OpenGVLab/OmniQuant/blob/8e6ca67d1b2cf55ee1fd7a543e9d4f5d2bef261f/quantize/utils.py#L77. For weights that are not modified during LET, we should copy them to temp_weight.
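To make the failure mode concrete, here is a minimal, self-contained PyTorch sketch (plain PyTorch, not OmniQuant code; all tensor names are illustrative) of how reusing a tensor that belongs to an old autograd graph across optimizer iterations produces exactly this error, and how rebuilding it each iteration avoids it:

```python
import torch

# Self-contained illustration (not OmniQuant code): a tensor built from a
# trainable parameter is reused across optimizer iterations, so the second
# backward revisits a graph whose saved tensors were already freed.
weight = torch.randn(4, 4)                 # weight that LET leaves untouched
scale = torch.nn.Parameter(torch.ones(4))  # learnable LET/LWC-style parameter
opt = torch.optim.AdamW([scale], lr=1e-2)
x = torch.randn(8, 4)

temp_weight = weight * scale               # built once, outside the loop
try:
    for _ in range(2):
        loss = (x @ temp_weight.t()).pow(2).mean()
        opt.zero_grad()
        loss.backward()                    # RuntimeError on the 2nd pass
        opt.step()
except RuntimeError as e:
    print(e)  # "Trying to backward through the graph a second time ..."

# Rebuilding temp_weight from the untouched weight on every iteration avoids it,
# which is the point of refreshing temp_weight in the referenced code path.
for _ in range(2):
    temp_weight = weight * scale           # rebuilt each iteration
    loss = (x @ temp_weight.t()).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```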

zkf331 commented 4 months ago


Thank you for your prompt reply. Following your instructions, the problem has been solved. When Codellama-34b uses LET, model.self_attn.o_proj.weight is not modified, so copying it to temp_weight resolves the issue.
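For reference, the added line mirrors the existing down_proj assignment linked above; a hedged sketch (the placement inside smooth_and_quant_temporary and the exact attribute paths are assumptions based on that link and the Llama module naming):

```python
# Sketch: give every weight that LET leaves untouched its own temp_weight,
# next to the existing down_proj line referenced above (quantize/utils.py#L77).
model.self_attn.o_proj.temp_weight = model.self_attn.o_proj.weight
model.mlp.down_proj.temp_weight = model.mlp.down_proj.weight  # existing line
```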

brisker commented 4 months ago

@ChenMnZ Why is it model.mlp.down_proj.temp_weight = model.mlp.down_proj.weight and not model.mlp.down_proj.temp_weight = model.mlp.down_proj.weight.data.clone().detach()?

Won't this turn the original weight into a quantized weight inside the smooth_and_quant_temporary function? Then the already-quantized weight would be quantized again in smooth_and_quant_inplace.

This seems to be a bug.
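For readers following the thread, the question above comes down to whether temp_weight ends up aliasing the original weight tensor. A minimal plain-tensor sketch (not OmniQuant code) of the two possible behaviors:

```python
import torch

# Aliasing vs. cloning: what `temp_weight = weight` actually does.
weight = torch.randn(2, 2)

temp_weight = weight                      # plain assignment: an alias, no copy
temp_weight = torch.round(temp_weight)    # rebinding to a new tensor: weight untouched
print(weight is temp_weight)              # False

temp_weight = weight
temp_weight.round_()                      # in-place op: weight is overwritten too
print(weight is temp_weight)              # True

safe_copy = weight.data.clone().detach()  # independent copy, as suggested above
```

So whether the original weight actually gets overwritten depends on whether the quantization step rebinds temp_weight to a new tensor or modifies it in place.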