KenRanmzes opened this issue 2 months ago
Thanks for the feedback! The problem is likely that 2-bit quantization is not supported in the current build. Is 2-bit an important feature in your use case? If so, I can provide a 2-bit build.
Yes, 2-bit is important for me. Additionally, I found that the shape (3072, 256) is not supported in TEMPLATE_TUNED_WITHOUT_M_CONFIGS.
Got it, thanks for letting us know!
Adding 2-bit support is relatively easy, but it seems the real problem is the lack of shape support for your use case. This is part of a bigger feature we are trying to release: currently we support only a fixed set of shapes, which we tuned ourselves. We are hoping to ship a just-in-time tuning feature soon, and that should resolve your problem.
Unfortunately, this will likely take a few days and won't be a quick fix on our end. If the 2-bit feature alone is useful to you, I can try to get that out relatively soon. Let me know!
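For reference, here is a minimal sketch of how one might inspect which shapes are currently tuned. It assumes TEMPLATE_TUNED_WITHOUT_M_CONFIGS is reachable via flute.utils and that its keys follow the layout suggested by the KeyError in the traceback below, i.e. (num_sms, num_bits, group_size, N, K, dtype); both points are assumptions, not FLUTE's documented API.

```python
# Minimal sketch (not FLUTE's documented API): list the tuned shape configs
# so you can see which (N, K) shapes are supported before converting a model.
# Assumption: TEMPLATE_TUNED_WITHOUT_M_CONFIGS is reachable via flute.utils and
# its keys look like (num_sms, num_bits, group_size, N, K, dtype), as suggested
# by the KeyError in the traceback below.
import flute.utils

def list_tuned_shapes(num_bits=2):
    """Return the (group_size, N, K, dtype) combinations tuned for `num_bits`."""
    shapes = set()
    for key in flute.utils.TEMPLATE_TUNED_WITHOUT_M_CONFIGS:
        sms, bits, group_size, n, k, dtype = key  # assumed key layout
        if bits == num_bits:
            shapes.add((group_size, n, k, dtype))
    return sorted(shapes)

# (3072, 256) would not appear here for 2-bit, matching the KeyError below.
print(list_tuned_shapes(num_bits=2))
```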
```
Traceback (most recent call last):
  File "/mnt/share/cq8/kennxiao/code/NF4Quant/test.py", line 41, in <module>
    prepare_model_flute(
  File "/data/miniconda3/envs/env-3.9.2/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/share/cq8/kennxiao/code/NF4Quant/flute_convert.py", line 175, in prepare_model_flute
    _replace_linear(name, module)
  File "/mnt/share/cq8/kennxiao/code/NF4Quant/flute_convert.py", line 173, in _replace_linear
    _replace_linear(child_full_name, child)
  File "/mnt/share/cq8/kennxiao/code/NF4Quant/flute_convert.py", line 173, in _replace_linear
    _replace_linear(child_full_name, child)
  File "/mnt/share/cq8/kennxiao/code/NF4Quant/flute_convert.py", line 147, in _replace_linear
    Q = flute.utils.pack(
  File "/mnt/share/cq8/kennxiao/dev/flute/flute/utils.py", line 276, in pack
    template_id = TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
KeyError: (108, 2, 128, 3072, 256, 'torch.float16')
```
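Until just-in-time tuning is available, one possible stopgap is to skip layers whose shapes have no tuned template instead of letting the pack call raise KeyError. The sketch below relies on the same assumptions as above (key layout and the flute.utils attribute); is_shape_tuned and the default values are illustrative, not part of FLUTE, and whether (3072, 256) maps to (out_features, in_features) or the reverse is also an assumption.

```python
# Stopgap sketch (not FLUTE's official API): check shape support before packing,
# so unsupported nn.Linear layers are left unconverted instead of raising KeyError.
# The key layout and the default values below are assumptions taken from the
# KeyError above: num_sms=108, num_bits=2, group_size=128, dtype='torch.float16'.
import flute.utils

def is_shape_tuned(n, k, num_sms=108, num_bits=2,
                   group_size=128, dtype="torch.float16"):
    key = (num_sms, num_bits, group_size, n, k, dtype)
    return key in flute.utils.TEMPLATE_TUNED_WITHOUT_M_CONFIGS

# Example use inside a conversion loop like _replace_linear in flute_convert.py:
# if not is_shape_tuned(module.out_features, module.in_features):
#     return  # leave this layer as a plain nn.Linear for now
```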