triton-lang / triton-cpu

An experimental CPU backend for Triton
https://triton-lang.org/
MIT License

error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast #144

Open vivekvpandya opened 1 month ago

vivekvpandya commented 1 month ago

Trying to run the following GPT-2 demo with triton-cpu; a certain Triton kernel fails with the error above:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch._inductor import config

@config.patch(cpu_backend="triton")
@torch.compile(backend="inductor")
def generate(prompt, model, tokenizer):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_tokens = model.generate(
        input_ids,
        # max_length=4,
        # do_sample=False,
        # pad_token_id=tokenizer.eos_token_id,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    return gen_text

def test_gpt2_demo():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(torch.float)
    prompt = "Thanks for"
    print("\nInput prompt: ", prompt, "\n")

    # run on CPU
    gen_text = generate(prompt, model, tokenizer)

    print("CPU output: ", gen_text, "\n")

test_gpt2_demo()

A dump of the Triton IR with more details is attached in the log file: log.txt
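
One way to extract the failing kernel into a standalone reproducer is PyTorch's standard TORCH_COMPILE_DEBUG switch (a sketch; the exact output layout may vary by PyTorch version):

import os

# Must be set before the first torch.compile call; Inductor then dumps the
# generated code (including the Triton kernels) under ./torch_compile_debug/.
os.environ["TORCH_COMPILE_DEBUG"] = "1"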

cc @ienkovich @minjang

ienkovich commented 1 month ago

From what I see in the log, this is due to the missing lowering for tl.device_assert. Asserts support has been added just recently: aa6a64e476f6fa4a50c69f626c09a312a2db65c8. Did you use the latest version with this patch included? If the latest version still fails, a Triton kernel reproducer would help a lot to debug the issue.