jy-yuan / KIVI

KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache
https://arxiv.org/abs/2402.02750
MIT License

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) #24

Closed. xzwj1699 closed this issue 6 days ago.

xzwj1699 commented 3 months ago

I ran into an error when trying KIVI; here is the code. (I modified example.py so it would run on my server.)

# LLaMA model with KIVI
import warnings
warnings.filterwarnings("ignore")
import torch
import random
from models.llama_kivi import LlamaForCausalLM_KIVI
from models.mistral_kivi import MistralForCausalLM_KIVI
from transformers import LlamaConfig, AutoTokenizer, MistralConfig
from datasets import load_dataset
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map

# For reproducibility
random.seed(0)
torch.manual_seed(0)

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config = MistralConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

config.k_bits = 2 # KIVI currently supports 2/4 bits for K/V
config.v_bits = 2
config.group_size = 32 
config.residual_length = 32 # corresponding to the number of recent fp16 tokens
CACHE_DIR = "./"

with init_empty_weights():
    model = MistralForCausalLM_KIVI.from_pretrained(
        pretrained_model_name_or_path=model_name,
        config=config,
        cache_dir=CACHE_DIR,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )

device_map = {
    'model.embed_tokens': 0,
    'model.layers.0': 0,
    'model.layers.1': 0,
    'model.layers.2': 0,
    'model.layers.3': 1,
    'model.layers.4': 1,
    'model.layers.5': 1,
    'model.layers.6': 1,
    'model.layers.7': 2,
    'model.layers.8': 2,
    'model.layers.9': 2,
    'model.layers.10': 2,
    'model.layers.11': 3,
    'model.layers.12': 3,
    'model.layers.13': 3,
    'model.layers.14': 3,
    'model.layers.15': 0,
    'model.layers.16': 0,
    'model.layers.17': 0,
    'model.layers.18': 1,
    'model.layers.19': 1,
    'model.layers.20': 1,
    'model.layers.21': 2,
    'model.layers.22': 2,
    'model.layers.23': 2,
    'model.layers.24': 2,
    'model.layers.25': 3,
    'model.layers.26': 3,
    'model.layers.27': 3,
    'model.layers.28': 3,
    'model.layers.29': 2,
    'model.layers.30': 2,
    'model.layers.31': 2,
    'model.norm': 1,
    'lm_head': 1
}

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=model_name,
    device_map=device_map,
    offload_folder=None,
    dtype=torch.float16
)

enc = AutoTokenizer.from_pretrained(
    model_name, 
    use_fast=False, 
    trust_remote_code=True)

dataset = load_dataset('gsm8k', 'main')

prompt = ''
for i in range(5):
    prompt += 'Question: ' + dataset['train'][i]['question'] + '\nAnswer: ' + dataset['train'][i]['answer'] + '\n'
prompt += "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?"
inputs = enc(prompt, return_tensors="pt").input_ids.cuda()

output = model.generate(inputs, max_new_tokens=96)
config_str = f"# prompt tokens: {inputs.shape[1]}, K bit: {config.k_bits}, v_bits: {config.v_bits}, group_size: {config.group_size}, residual_length: {config.residual_length}"

print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nKiVi Output:")
print(enc.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True))

Here is the full error log:

(kivi) root@39e0295decad:~/workspace/KIVI# python example.py 
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.26it/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "/root/workspace/KIVI/example.py", line 111, in <module>
    output = model.generate(inputs, max_new_tokens=96)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/workspace/KIVI/models/mistral_kivi.py", line 999, in forward
    outputs = self.model(
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/KIVI/models/mistral_kivi.py", line 887, in forward
    layer_outputs = decoder_layer(
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/workspace/KIVI/models/mistral_kivi.py", line 719, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/miniconda/envs/kivi/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/workspace/KIVI/models/mistral_kivi.py", line 249, in forward
    key_states_quant_trans, key_scale_trans, key_mn_trans = triton_quantize_and_pack_along_last_dim(key_states_quant.transpose(2, 3).contiguous(), self.group_size, self.k_bits)
  File "/root/workspace/KIVI/quant/new_pack.py", line 232, in triton_quantize_and_pack_along_last_dim
    _minmax_along_last_dim[grid](data, mn, mx,
  File "<string>", line 45, in _minmax_along_last_dim
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

I guess there may be something wrong with my modified code, and it would be great if you could help.

Thank you!

condy0919 commented 1 week ago

You have to manually set the CUDA device before launching the Triton kernel, e.g.:

with torch.cuda.device(data.device):
    _minmax_along_last_dim[grid](
        data, mn, mx,
        data.numel(), data.shape[0], num_groups, group_size,
        BLOCK_SIZE_N=BLOCK_SIZE_N, num_warps=8,
    )
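
For reference, here is a minimal self-contained sketch of the same pattern outside of KIVI: a throwaway Triton copy kernel launched on a tensor that lives on a non-default GPU, which is the situation accelerate's device_map creates for the decoder layers above. The kernel and helper names are made up for the illustration, and it assumes at least two visible GPUs.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)

def copy_on_device(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), 1024),)
    # The guard makes src's device the current CUDA device for the launch.
    # Without it, Triton targets the default device (usually cuda:0) and
    # raises "Pointer argument (at 0) cannot be accessed from Triton" for
    # tensors that live on another GPU.
    with torch.cuda.device(src.device):
        _copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=1024)
    return dst

# Hypothetical usage: a tensor placed on a second GPU.
x = torch.randn(4096, device="cuda:1")
y = copy_on_device(x)
assert torch.equal(x, y)
```

Applying the same guard around the kernel launches in quant/new_pack.py, as in the snippet above, should let the multi-GPU device_map in this issue work.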