mobiusml / hqq

Official implementation of Half-Quadratic Quantization (HQQ)
https://mobiusml.github.io/hqq_blog/
Apache License 2.0

8bit + Aten + compile #130

Open zhangy659 opened 3 weeks ago

zhangy659 commented 3 weeks ago

When I try to run patch_model_for_compiled_runtime with 8-bit + ATEN, the program reports an error (see the attached screenshot). How can I solve this problem?

code

import torch
import torch.fx
import time

device        = 'cuda:0'
backend       = 'torchao_int4'  # "torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend == "bitblas" else torch.bfloat16
cache_dir     = '.'
model_id      = './llama/llama3/Meta-Llama-3-8B'
########################################################################

# Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model     = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")

# Quantize
quant_config = BaseQuantizeConfig(nbits=8, group_size=64, axis=0)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.ATEN_FORWARD)

# Inference
from hqq.utils.generation_hf import patch_model_for_compiled_runtime

patch_model_for_compiled_runtime(model, tokenizer, warmup=True)

WARMUP_PROMPTS = [
    "Write an essay about large language models.",
    "Tell me a funny joke!",
    "Hello, my name is Kiven, I like England for five reasons. First,",
    "Who is Elon Musk?",
    "Write a Python code snippet that adds two numbers together.",
]

for prompt in WARMUP_PROMPTS:
    inputs_warmup = tokenizer(prompt, return_tensors='pt', padding='max_length', max_length=128, truncation=True).to(model.device)
    torch.cuda.synchronize()
    warmup_start = time.time()
    output = model.generate(**inputs_warmup, max_new_tokens=1000, cache_implementation="static", pad_token_id=tokenizer.pad_token_id)
    torch.cuda.synchronize()
    warmup_end = time.time()
    print(warmup_end - warmup_start)

mobicham commented 3 weeks ago

You need to patch the model for inference first, because by default the model is ready for QLoRA training, which is not compatible with torch.compile:

...
HQQLinear.set_backend(HQQBackend.ATEN)
from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model)

#Inference
from hqq.utils.generation_hf import patch_model_for_compiled_runtime
patch_model_for_compiled_runtime(model, tokenizer, warmup=True)
...
mobicham commented 3 weeks ago

If you are using the compiled runtime, you can also use HQQLinear.set_backend(HQQBackend.PYTORCH), which works with both axis=0 and axis=1; ATEN only works with axis=0.
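For reference, a minimal sketch of that variant, assuming the same model/tokenizer setup as in the script above (with PYTORCH the quant_config can use either axis=0 or axis=1):

...
HQQLinear.set_backend(HQQBackend.PYTORCH)  # works with both axis=0 and axis=1

from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model)

from hqq.utils.generation_hf import patch_model_for_compiled_runtime
patch_model_for_compiled_runtime(model, tokenizer, warmup=True)
...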

zhangy659 commented 3 weeks ago

Thank you very much. I noticed that patch_hqq_inference() in prepare_for_inference replaces the forward function with forward_hqq_inference, and that forward_hqq_inference() is different from HQQLinear's own forward_aten. In that case, won't the ATEN backend end up not being used?

mobicham commented 3 weeks ago

Oh yes, it's a bit confusing; in short, no. ATEN, PYTORCH and PYTORCH_COMPILE are native backends for the hqq lib. The backends in prepare_for_inference are external CUDA / Triton kernels that speed up inference: the HQQLinear layers are swapped with a whole new layer that uses those external backends. So when you call prepare_for_inference with a backend like torchao_int4 or bitblas, it will no longer use HQQBackend; it will use the CUDA / Triton kernels from those backends.
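To make the distinction concrete, a rough sketch of the two paths (the backend keyword follows the hqq examples and should be treated as an assumption here):

# (a) Native hqq backend: HQQLinear keeps its own forward (PYTORCH / ATEN / ...)
HQQLinear.set_backend(HQQBackend.PYTORCH)
prepare_for_inference(model)  # default patching, as in the snippet above; the native backend still handles the forward

# (b) External kernels: HQQLinear layers are swapped for layers that call the
#     external CUDA / Triton kernels, so HQQBackend no longer matters.
# prepare_for_inference(model, backend="torchao_int4")  # 4-bit only
# prepare_for_inference(model, backend="bitblas")       # 4-bit and 2-bit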

zhangy659 commented 3 weeks ago

Thank you very much. However, I don't think this is the main issue. Since PyTorch 2.4, the binding implementation for C++/CUDA operators has changed (https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html). When testing, I found that the methods in hqq/kernels are not compatible with torch.compile() and report the following issue (see the attached screenshot). I'm not sure if this bug is due to an issue with the binding logic or a problem with my environment.
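For context, the pattern the linked tutorial recommends for torch.compile-friendly bindings looks roughly like this (an illustrative sketch, not code from hqq/kernels; the op name and body are made up):

import torch

@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    # In the real case this would dispatch to a compiled C++/CUDA kernel.
    return a + alpha * b

@scaled_add.register_fake
def _(a, b, alpha):
    # Shape/dtype propagation so torch.compile can trace without running the kernel.
    return torch.empty_like(a)

compiled_fn = torch.compile(lambda x, y: scaled_add(x, y, 0.5))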

mobicham commented 3 weeks ago

Oh, it could be actually, thanks for checking! Just use the PYTORCH backend, then prepare_for_inference; it should work. I haven't used ATEN with axis=0 in months :D