zhangy659 opened this issue 3 weeks ago
You need to patch the model for inference first, because by default the model is ready for QLoRA training, which is not compatible with torch.compile:
...
HQQLinear.set_backend(HQQBackend.ATEN)

# Patch the quantized layers for inference (replaces the default QLoRA-training forward)
from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model)

# Inference: set up the compiled runtime for generation
from hqq.utils.generation_hf import patch_model_for_compiled_runtime
patch_model_for_compiled_runtime(model, tokenizer, warmup=True)
...
If you are using the compiled runtime, you can also use HQQLinear.set_backend(HQQBackend.PYTORCH), which works with both axis=0 and axis=1; ATEN only works with axis=0.
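For example, a minimal sketch of pairing axis=1 with the PYTORCH backend, reusing the quantization calls shown later in this thread and assuming `model` is already loaded:

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, HQQBackend
from hqq.models.hf.base import AutoHQQHFModel

# axis=1 quantization works with the native PYTORCH backend (ATEN would require axis=0)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.bfloat16, device='cuda:0')
HQQLinear.set_backend(HQQBackend.PYTORCH)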
Thank you very much. I noticed that patch_hqq_inference() in prepare_for_inference replaces the forward function with forward_hqq_inference. This forward_hqq_inference() is different from HQQLinear's own forward_aten. In that case, won't the ATEN backend no longer work?
Oh yes, it's a bit confusing; in short, no.
ATEN, PYTORCH and PYTORCH_COMPILE are native backends of the hqq lib.
The ones in prepare_for_inference are external CUDA/Triton kernels used to speed up inference: the HQQLinear layers are swapped out for entirely new layers that use those external backends. So when you run prepare_for_inference with a backend like torchao_int4 or bitblas, it will no longer use HQQBackend; it will use the CUDA/Triton kernels from those backends.
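So, if I understand the above correctly, switching to an external kernel backend is just a matter of passing it to prepare_for_inference, roughly like this (assuming a 4-bit, axis=1 quantized model for torchao_int4):

from hqq.utils.patching import prepare_for_inference

# Swaps HQQLinear layers for ones backed by external CUDA/Triton kernels;
# the native HQQBackend setting no longer applies after this call.
prepare_for_inference(model, backend="torchao_int4")  # or backend="bitblas"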
Thank you very much. However, I don't think this is the main issue. Since PyTorch 2.4, the binding implementation for C++/CUDA operators has changed (https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html). When testing, I found that the methods in hqq/kernels are not compatible with torch.compile() and raise an error. I'm not sure whether this bug is due to an issue with the binding logic or a problem with my environment.
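For what it's worth, a quick way to check whether a single quantized layer survives torch.compile in isolation, so the failure isn't hidden inside the full generation loop; `layer` and `check_layer_compile` are placeholders made up for illustration, not part of hqq:

import torch

def check_layer_compile(layer, in_features, dtype=torch.bfloat16, device='cuda:0'):
    # fullgraph=True makes torch.compile raise on any op it cannot trace,
    # so an incompatible custom-op binding surfaces immediately.
    compiled = torch.compile(layer, fullgraph=True)
    x = torch.randn(1, in_features, dtype=dtype, device=device)
    with torch.no_grad():
        return compiled(x)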
Oh, could be actually, thanks for checking! Just use the PYTORCH backend, then prepare_for_inference; it should work.
I haven't used ATEN with axis=0 in months :D!
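Putting the suggestion together, something like the following (all calls taken from the snippets earlier in the thread):

from hqq.core.quantize import HQQLinear, HQQBackend
from hqq.utils.patching import prepare_for_inference
from hqq.utils.generation_hf import patch_model_for_compiled_runtime

HQQLinear.set_backend(HQQBackend.PYTORCH)  # native backend, avoids the ATEN custom-op path
prepare_for_inference(model)               # patch the quantized layers for inference
patch_model_for_compiled_runtime(model, tokenizer, warmup=True)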
When I try to run patch_model_for_compiled_runtime with 8-bit + ATEN, the program reports an error. How can I solve this problem?
Code:

import torch
import torch.fx
import time

device = 'cuda:0'
backend = 'torchao_int4' # "torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit)
compute_dtype = torch.float16 if backend == "bitblas" else torch.bfloat16
cache_dir = '.'
model_id = './llama/llama3/Meta-Llama-3-8B'

########################################################################
# Load model
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")

# Quantize
quant_config = BaseQuantizeConfig(nbits=8, group_size=64, axis=0)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.ATEN_FORWARD)

# Inference
from hqq.utils.generation_hf import patch_model_for_compiled_runtime
patch_model_for_compiled_runtime(model, tokenizer, warmup=True)

WARMUP_PROMPTS = [
    "Write an essay about large language models.",
    "Tell me a funny joke!",
    "Hello, my name is Kiven, I like England for five reasons. First,",
    "Who is Elon Musk?",
    "Write a Python code snippet that adds two numbers together.",
]
for prompt in WARMUP_PROMPTS:
    inputs_warmup = tokenizer(prompt, return_tensors='pt', padding='max_length', max_length=128, truncation=True).to(model.device)
    torch.cuda.synchronize()
    warmup_start = time.time()
    output = model.generate(**inputs_warmup, max_new_tokens=1000, cache_implementation="static", pad_token_id=tokenizer.pad_token_id)
    torch.cuda.synchronize()
    warmup_end = time.time()
    print(warmup_end - warmup_start)