nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

stateless_llama_test test_vmfb_comparison broken with PyTorch 2.3 #601

Open saienduri opened 3 months ago

saienduri commented 3 months ago

See: models/turbine_models/tests/stateless_llama_test.py

Marked as expectedFailure.

This test is failing during the export/tracing stage with the following error:

FAILED models/turbine_models/tests/stateless_llama_test.py::StatelessLlamaChecks::test_vmfb_comparison - TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool

From the "Error Cookbook" section of the dynamic shapes manual, it seems we would have to change the operator schema in torch so that it accepts SymBool rather than only bool. For now, I was able to work around it by changing this in modeling_llama.py:

is_causal=self.is_causal and attention_mask is None and q_len > 1

to

is_causal=True if (self.is_causal and attention_mask is None and q_len > 1) else False
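
For context, a minimal, self-contained sketch of the coercion that workaround performs. The function and argument names below are my own illustration (not the actual transformers code); it only assumes the standard torch.nn.functional SDPA API:

```python
import torch
import torch.nn.functional as F

def sdpa_with_concrete_is_causal(query, key, value, attention_mask, q_len, module_is_causal):
    # Collapse the (possibly symbolic) condition into a concrete Python bool so
    # that scaled_dot_product_attention's `is_causal: bool` argument type-checks
    # during export; evaluating the condition installs a guard on q_len.
    is_causal = True if (module_is_causal and attention_mask is None and q_len > 1) else False
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )
```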

But after this change, I ran into this error in the iree-compile stage:

Diagnostics:
E               <stdin>:1101:13: error: failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
E                   %73:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%64, %72, %44, %float0.000000e00, %true_61, %none_62, %none_63) : (!torch.vtensor<[1,32,?,128],f32>, !torch.vtensor<[1,32,?,128],f32>, !torch.vtensor<[1,32,?,128],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,32,?,128],f32>, !torch.vtensor<[1,32,?],f32>)

Do we need to register a decomposition for this op to make this work?
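
If a decomposition is the way to go, the following is one possible shape for it (sketch only; it assumes the plain torch.export path and that upstream PyTorch registers a decomposition for this op; the actual Turbine export pipeline may hook decompositions differently):

```python
import torch
from torch._decomp import get_decompositions


class TinySdpa(torch.nn.Module):
    # Stand-in module so the sketch is self-contained; not the Llama model.
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


example_args = (torch.randn(1, 32, 16, 128),) * 3
exported = torch.export.export(TinySdpa(), example_args)

# Request a decomposition for the op iree-compile rejected. If PyTorch has one
# registered, run_decompositions() rewrites the graph so the op no longer
# appears; if not, the table comes back empty and the graph is unchanged.
decomp_table = get_decompositions([
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
])
exported = exported.run_decompositions(decomp_table)
print(exported.graph_module.code)
```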

IanNod commented 3 months ago

@saienduri what backend/device were you facing this issue on, and did you also try testing with an updated IREE build?

@Groverkss are there any expected failure cases like this with SDPA enabled for now?

saienduri commented 3 months ago

This is on the cpu backend, and I was trying with the IREE build pinned in Turbine from 04/03. I just tried with the latest IREE release as well, but still hit the same issue.