mirage-project / mirage

A multi-level tensor algebra superoptimizer
https://mirage-project.readthedocs.io/
Apache License 2.0

Using Mirage for different models #27

Open ramyaprabhu-alt opened 3 months ago

ramyaprabhu-alt commented 3 months ago

Hi, I just found this repository and I really like the idea. I wanted to try it out on a different model than the one in this repo's README, say Llama 3 8B with TP2. But I'm a novice and am struggling to understand how the inputs in the given example need to be modified.

[Image: screenshot of the example in question]

1.) From all my looking around, I assume this is a prefill kernel for a context length of 4096 with an input size of 256. Am I correct?
2.) For each of the Q, K, and V tensors, why is the first dimension of the input dims tuple 2 * batch size?
3.) What is the 64 in the input dims tuple supposed to be? The README says this is a kernel for Llama 70B with TP4...
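For reference when adapting the example, the per-GPU attention dimensions of a Llama-style model under tensor parallelism can be worked out as below. This is a sketch, not Mirage code. The config values are the published Llama 3 8B and 70B settings (hidden size, query heads, KV heads), while `LlamaAttentionConfig` and `per_gpu_attention_dims` are hypothetical helpers. The sketch assumes query and KV heads are split evenly across TP ranks, and it does not claim to show how the Mirage example maps these numbers onto its input dims.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LlamaAttentionConfig:
    """Hypothetical container for the attention-related fields of a Llama config."""
    hidden_size: int
    num_attention_heads: int   # number of query heads
    num_key_value_heads: int   # number of KV heads (grouped-query attention)


def per_gpu_attention_dims(cfg: LlamaAttentionConfig, tp: int) -> dict:
    """Split attention heads evenly across `tp` tensor-parallel ranks.

    Assumption: heads are partitioned, not replicated, so both head counts
    must be divisible by the TP degree.
    """
    if cfg.num_attention_heads % tp or cfg.num_key_value_heads % tp:
        raise ValueError("head counts must be divisible by the TP degree")
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    q_heads = cfg.num_attention_heads // tp
    kv_heads = cfg.num_key_value_heads // tp
    return {
        "head_dim": head_dim,
        "q_heads_per_gpu": q_heads,
        "kv_heads_per_gpu": kv_heads,
        # How many query heads share each KV head (the GQA group size).
        "q_heads_per_kv_head": q_heads // kv_heads,
    }


if __name__ == "__main__":
    llama3_8b = LlamaAttentionConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
    llama3_70b = LlamaAttentionConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)
    print("Llama 3 8B, TP2:", per_gpu_attention_dims(llama3_8b, tp=2))
    print("Llama 3 70B, TP4:", per_gpu_attention_dims(llama3_70b, tp=4))
```

Under these assumptions both setups have a head dim of 128 and 16 query heads per GPU. They differ in KV heads per GPU (4 for 8B with TP2, 2 for 70B with TP4) and therefore in GQA group size (4 vs. 8).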