pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

GPTQ quantization not working #12

Open lopuhin opened 11 months ago

lopuhin commented 11 months ago

Running quantize.py with --mode int4-gptq does not seem to work.

Overall here are the fixes I had to apply to make it run: https://github.com/lopuhin/gpt-fast/commit/86d990bfbce46d10169c8e21e3bfec5cbd203b96

Based on this, could you please check if the right version of the code was included for GPTQ quantization?
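
For context, the failing run is an invocation of roughly the following form (flag names are inferred from the quantize() call visible in the tracebacks below; the checkpoint path and values are placeholders):

    python quantize.py --checkpoint_path checkpoints/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 10 --groupsize 128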

lopuhin commented 11 months ago

One more issue is very high memory usage: it exceeds 128 GB after processing only the first 9 layers of the 13B model.

jamestwhedbee commented 11 months ago

I am stuck at the third bullet point here as well; I'm going to just follow along with the comments here.

lopuhin commented 11 months ago

@jamestwhedbee to get rid of those Python issues you can try this fork in the meantime: https://github.com/lopuhin/gpt-fast/ -- but I don't have a solution for the high RAM usage yet, so in the end I didn't manage to get a converted model.

jamestwhedbee commented 11 months ago

That looked promising, but I unfortunately ran into another issue you probably wouldn't have. I am on AMD, so that might be the cause; I can't find anything online related to this error. I noticed that non-GPTQ int4 quantization does not work for me either, failing with the same error. int8 quantization works fine, and I have run GPTQ int4-quantized models using the auto-gptq library for ROCm before, so I'm not sure what the issue is.


Traceback (most recent call last):
  File "/home/telnyxuser/gpt-fast/quantize.py", line 614, in <module>
    quantize(args.checkpoint_path, args.model_name, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
  File "/home/telnyxuser/gpt-fast/quantize.py", line 560, in quantize
    quantized_state_dict = quant_handler.create_quantized_state_dict()
  File "/home/telnyxuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/telnyxuser/gpt-fast/quantize.py", line 423, in create_quantized_state_dict
    weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
  File "/home/telnyxuser/gpt-fast/quantize.py", line 358, in prepare_int4_weight_and_scales_and_zeros
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
  File "/home/telnyxuser/.local/lib/python3.10/site-packages/torch/_ops.py", line 753, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: _convert_weight_to_int4pack_cuda is not available for build.

lopuhin commented 11 months ago

I got the same error when trying a conversion on another machine with more RAM but an older NVIDIA GPU.

MrD005 commented 11 months ago

Has anyone solved all the problems? I am getting every problem discussed in this thread.

MrD005 commented 11 months ago

@jamestwhedbee @lopuhin I am stuck on this:

Traceback (most recent call last):
  File "quantize.py", line 614, in <module>
    quantize(args.checkpoint_path, args.model_name, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
  File "quantize.py", line 560, in quantize
    quantized_state_dict = quant_handler.create_quantized_state_dict()
  File "/root/development/dev/venv/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "quantize.py", line 423, in create_quantized_state_dict
    weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
  File "quantize.py", line 358, in prepare_int4_weight_and_scales_and_zeros
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
  File "/root/development/dev/venv/lib/python3.8/site-packages/torch/_ops.py", line 753, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: _convert_weight_to_int4pack_cuda is not available for build.

are you guys able to solve this?

lopuhin commented 11 months ago

RuntimeError: _convert_weight_to_int4pack_cuda is not available for build.

@MrD005 I got this error when trying to run on a 2080 Ti but not on an L4 (both using CUDA 12.1), so I suspect the function is simply missing for GPUs with lower compute capability.

MrD005 commented 11 months ago

@lopuhin I am running it on an A100 with Python 3.8 and the CUDA 11.8 nightly build, so I don't think it is about lower compute capability.

chu-tianxiang commented 11 months ago

According to the code here, probably both CUDA 12.x and compute capability 8.0+ are required.
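
If that is right, a quick preflight check along these lines (a hypothetical helper, not part of gpt-fast) can tell you whether a machine is likely to support the int4 packing kernel before you start a long conversion:

    import torch

    def int4pack_likely_supported() -> bool:
        # Heuristic based on the requirements discussed in this thread:
        # a CUDA 12.x build of PyTorch and a GPU with compute capability >= 8.0
        # (e.g. A100, RTX 30/40 series). ROCm builds report torch.version.cuda as None.
        if not torch.cuda.is_available() or torch.version.cuda is None:
            return False
        major, _ = torch.cuda.get_device_capability()
        cuda_major = int(torch.version.cuda.split(".")[0])
        return major >= 8 and cuda_major >= 12

    print("int4 packing likely supported:", int4pack_likely_supported())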

briandw commented 11 months ago

I had the same _convert_weight_to_int4pack_cuda not available problem. It was due to CUDA 11.8 not supporting the operator. It works now with an RTX 4090 and CUDA 12.1.

xin-li-67 commented 11 months ago

I got this problem on my single RTX 4090 with the PyTorch nightly built for CUDA 11.8. After I switched to the PyTorch nightly for CUDA 12.1, the problem was gone.

lufixSch commented 10 months ago

@jamestwhedbee did you find a solution for ROCm?

jamestwhedbee commented 10 months ago

@lufixSch no, but as of last week v0.2.7 of vLLM supports GPTQ with ROCm, and I am seeing pretty good results there. So maybe that is an option for you.
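
For anyone going that route, the vLLM usage looks roughly like this (a sketch; the model name is a placeholder and exact arguments may differ between vLLM versions):

    from vllm import LLM, SamplingParams

    # Load a GPTQ-quantized checkpoint; the repo id below is only a placeholder.
    llm = LLM(model="TheBloke/Llama-2-13B-GPTQ", quantization="gptq")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)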

ce1190222 commented 9 months ago

I applied all the fixes mentioned, but I'm still getting this error:

  File "/kaggle/working/quantize.py", line 14, in <module>
    from GPTQ import GenericGPTQRunner, InputRecorder
  File "/kaggle/working/GPTQ.py", line 12, in <module>
    from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill
  File "/kaggle/working/eval.py", line 20, in <module>
    import lm_eval.base
ModuleNotFoundError: No module named 'lm_eval.base'

I am using lm_eval 0.4.0

jerryzh168 commented 9 months ago

Support for lm_eval 0.3.0 and 0.4.0 was updated in https://github.com/pytorch-labs/gpt-fast/commit/eb1789be0bdb7a7b75291f0839532ce1931305a2
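
The error above comes from lm_eval 0.4.x moving the model base class out of lm_eval.base; the usual compatibility pattern looks something like this sketch (the linked commit is the authoritative fix and may differ):

    try:
        # lm_eval >= 0.4.0
        from lm_eval.api.model import LM
    except ImportError:
        # lm_eval 0.3.x
        from lm_eval.base import LM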

petrex commented 1 week ago

GPTQ should be working for ROCm at the moment (ROCm 6.2); if not, please let us know the details.