linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training

Inference with the Qwen2 model produces garbled output, and ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) #268

Open Dujianhua1008 opened 4 days ago

Dujianhua1008 commented 4 days ago

🐛 Describe the bug

When I load the model with AutoLigerKernelForCausalLM, I get ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?).

[screenshot: ValueError traceback]

When I load the model via the model-specific patching APIs instead, the generated output is garbled.

[screenshot: garbled generation output]
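
The "(cpu tensor?)" hint suggests a Triton kernel received a pointer to a tensor that is not on the GPU it launched on. One quick way to check where everything actually lives before calling generate (a minimal sketch, using the variable names from the repro below):

for name, t in model_inputs.items():
    print(name, t.device)               # every input tensor should be on the model's device
print(next(model.parameters()).device)  # device of the model's first parameter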

Reproduce


import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_qwen2

model_path = "Qwen/Qwen2-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map='cuda:0')

# Patch Qwen2 globally, then patch the already-instantiated model in place.
apply_liger_kernel_to_qwen2()
apply_liger_kernel_to_qwen2(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
    model=model,
)

Liger_model = AutoLigerKernelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map='cuda:1')

def generate(model, model_inputs, config):
    # device = next(model.model.parameters()).device
    # model_inputs = model_inputs.to(device)
    with torch.cuda.amp.autocast():
        generated_ids = model.generate(
            model_inputs.input_ids,
            min_new_tokens=config['min_new_tokens'],
            max_new_tokens=config['max_new_tokens'],
            # pad_token_id=tokenizer.eos_token_id
        )
    # Strip the prompt tokens from the generated sequences.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(len(generated_ids[0]))
    print(response)

prompt = "Hey, are you conscious? Can you talk to me?"
model_inputs = tokenizer(prompt, return_tensors="pt").to('cuda:0')
config = {
    'min_new_tokens': 512,
    'max_new_tokens': 512,
}

generate(Liger_model, model_inputs, config)
# raises ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
generate(model, model_inputs, config)
# produces garbled output
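
Note that the repro tokenizes the inputs onto cuda:0 while Liger_model is loaded with device_map='cuda:1', so the failing generate(Liger_model, ...) call mixes devices. The commented-out lines in generate() hint at the alignment; a minimal sketch of that fix (using the variables above), which may rule out the cross-device pointer issue:

device = next(Liger_model.parameters()).device  # cuda:1 under device_map='cuda:1'
model_inputs = model_inputs.to(device)          # move all input tensors to the model's device
generate(Liger_model, model_inputs, config)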

Versions

Environment Report:
-------------------
Operating System: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Python version: 3.10.14
PyTorch version: 2.3.1
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.43.4
chiwanpark commented 1 day ago

@Dujianhua1008 cc @tyler-romero Here is a code snippet with which inference succeeds for me:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM

model_path = "Qwen/Qwen2-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)

from liger_kernel.transformers import apply_liger_kernel_to_qwen2
apply_liger_kernel_to_qwen2(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=True,
)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="cuda:0")

def generate(model, model_inputs, config):
    # Generate
    with torch.cuda.amp.autocast():
        generated_ids = model.generate(
            model_inputs.input_ids,
            min_new_tokens=config['min_new_tokens'],
            max_new_tokens=config['max_new_tokens'],
        )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(len(generated_ids[0]))
    print(response)

prompt = "Hey, are you conscious? Can you talk to me?"
model_inputs = tokenizer(prompt, return_tensors="pt").to('cuda:0')
config = {
    'min_new_tokens': 512,
    'max_new_tokens': 512,
}

generate(model, model_inputs, config)

It seems that there is a bug when we apply the Liger kernel to an already instantiated model. I have started a discussion about it on the GPU MODE Discord.
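
For reference, the ordering difference between the two patterns in this thread, as a sketch (same arguments as above):

# Pattern that worked above: patch the Qwen2 classes first, then instantiate,
# so the model is built from the already-patched classes.
apply_liger_kernel_to_qwen2(rope=True, swiglu=True, cross_entropy=True,
                            fused_linear_cross_entropy=False, rms_norm=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="cuda:0")

# Pattern from the original report that produced garbled output: instantiate
# first, then pass the existing instance via model= to patch it in place.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="cuda:0")
apply_liger_kernel_to_qwen2(rope=True, swiglu=True, cross_entropy=True,
                            fused_linear_cross_entropy=False, rms_norm=False, model=model)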