llvm / torch-mlir

The Torch-MLIR project aims to provide first-class support from the PyTorch ecosystem to the MLIR ecosystem.

Compile torch.nn.functional.scaled_dot_product_attention failed #3473

Open rednoah91 opened 2 weeks ago

rednoah91 commented 2 weeks ago

Hi,

I'd like to compile projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py with torch-mlir, so I made the following modification:

--- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
+++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
@@ -22,7 +22,8 @@ data = torch.randint(30522, (2, 128))
 out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"

 module = torchscript.compile(
-    model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
+    model, data, output_type="linalg-on-tensors", use_tracing=True
 )

But I encountered the following error:

python exception: Failure while executing pass pipeline:
error: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
note: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): see current operation: %111 = "torch.aten.scaled_dot_product_attention"(%100, %105, %110, %93, %48, %58, %57) : (!torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,1,128,128],f32>, !torch.float, !torch.bool, !torch.none) -> !torch.vtensor<[2,2,128,64],f32>

My environment:

torch-mlir: 20240608.126
transformers: 4.41.2

Have you encountered this before, or did I miss something?

Thanks.

IanWood1 commented 4 days ago

It looks like only the default inputs are supported; I'm not sure if there is a plan to support this:

https://github.com/llvm/torch-mlir/blob/ca0e9066755b35c0889c6ab792265b0886325f50/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp#L1573-L1575
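For context: the failing op in the error above carries a non-none fourth operand (the `%93` attention mask of type `!torch.vtensor<[2,1,128,128],f32>`), which BERT's self-attention passes, and that is exactly the non-default case the lowering rejects. Below is a plain-Python sketch of what `scaled_dot_product_attention` computes per the documented PyTorch semantics, `softmax(QK^T / sqrt(d) + mask) V` (ignoring dropout and the causal flag); it is illustrative only, not torch-mlir code:

```python
import math

def sdpa(query, key, value, attn_mask=None):
    """Reference scaled dot-product attention: softmax(Q K^T / sqrt(d) + mask) V.

    Pure-Python sketch of the documented semantics of
    torch.nn.functional.scaled_dot_product_attention (no dropout, non-causal).
    query/key/value are lists of rows; attn_mask is an optional additive mask.
    """
    d = len(query[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, q in enumerate(query):
        # scaled scores for query row i against every key row
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key]
        if attn_mask is not None:
            # the non-default operand that the TorchToTMTensor lowering rejects
            scores = [s + m for s, m in zip(scores, attn_mask[i])]
        # numerically stable softmax over the scores
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # weighted sum of value rows
        out.append([sum(w * v[j] for w, v in zip(weights, value))
                    for j in range(len(value[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

print(sdpa(Q, K, V))        # default (mask-free) case, which the lowering handles
mask = [[0.0, -1e9], [-1e9, 0.0]]
print(sdpa(Q, K, V, mask))  # masked case, as the traced BERT encoder emits
```

If the masked path is needed before the lowering grows support for it, one possible workaround is to express attention with these primitive ops (matmul, softmax) in the model itself before tracing, so that no `aten.scaled_dot_product_attention` op reaches the legalization pass.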