ElEHsiang opened this issue 6 months ago
Our current implementation of scaled dot product attention doesn't yet support is_causal (should be coming soon). For now, I would suggest decomposing the attention op. @gpetters-amd, weren't you looking at adding that recently?
You can see an example of how we added that as an option in our SDXL pipeline; it's enabled with the --decomp_attn flag: https://github.com/nod-ai/SHARK-Turbine/blob/f919efe78903727d149c997e326f41b54ea1e147/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py#L117
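Roughly, "decomposing the attention op" means lowering the fused SDPA call into plain matmul/softmax ops, which handle causal masking without the fused kernel. A minimal, illustrative sketch of the equivalent math (not the pipeline's actual code):

```python
import math
import torch

def sdpa_decomposed(q, k, v, is_causal: bool = True):
    # Same math as torch.nn.functional.scaled_dot_product_attention:
    # softmax(q @ k^T / sqrt(d) + causal_mask) @ v, spelled out so no
    # fused attention kernel is required.
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    if is_causal:
        L, S = q.shape[-2], k.shape[-2]
        mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```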
I am getting the same error even after adding the decomposition. These were added:

with decompositions.extend_aot_decompositions(
    from_current=True,
    add_ops=[
        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten.masked_fill_.Scalar,
        torch.ops.aten.copy,
    ],
):
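For reference, a sketch of how that context manager is typically wrapped around the actual export so the added decompositions take effect while the model is traced. This is a minimal sketch, not the stateless_llama flow: the tiny module, example tensors, and the `aot.export` / `save_mlir` calls below are assumptions based on how other turbine_models scripts export.

```python
import torch
from shark_turbine import aot
from shark_turbine.aot import decompositions


class TinyAttn(torch.nn.Module):
    # Placeholder module; stands in for the real TinyLlama export target.
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


x = torch.randn(1, 8, 16, 64)

with decompositions.extend_aot_decompositions(
    from_current=True,
    add_ops=[
        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
        torch.ops.aten._scaled_dot_product_flash_attention.default,
    ],
):
    # The export has to run inside the `with` block so the extended
    # decomposition list is active during tracing.
    exported = aot.export(TinyAttn(), x, x, x)
    exported.save_mlir("attn.mlir")
```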
I tried to compile the TinyLlama-1.1B-Chat-v1.0 model to a vmfb but failed: the parameter data types do not match in torch.nn.functional.scaled_dot_product_attention(). How can I fix it?

PS. I am working from commit 4a01c405843fd91badbea2a14fd19e0393aade8f because iree-turbine does not have stateless_llama.py.

Command

Error log
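For illustration, a minimal sketch of the kind of dtype mismatch this error usually points to; the shapes and dtypes below are assumptions, not taken from the actual TinyLlama export:

```python
import torch
import torch.nn.functional as F

# q in float32 while k/v are float16: SDPA requires matching dtypes,
# so calling it like this fails with a dtype-mismatch error.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
v = torch.randn(1, 8, 16, 64, dtype=torch.float16)

# F.scaled_dot_product_attention(q, k, v, is_causal=True)  # raises

# One way to avoid it: cast everything to a single dtype before the call.
out = F.scaled_dot_product_attention(q.to(torch.float16), k, v, is_causal=True)
```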