rednoah91 opened 2 weeks ago
Hi,
I'd like to compile projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py with torch-mlir to the linalg-on-tensors backend instead of StableHLO, so I made the following modification:
    --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
    +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
    @@ -22,7 +22,8 @@ data = torch.randint(30522, (2, 128))
     out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
     module = torchscript.compile(
    -    model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
    +    model, data, output_type="linalg-on-tensors", use_tracing=True
     )
But I encountered the following error:
    python exception: Failure while executing pass pipeline:
    error: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
    note: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): see current operation: %111 = "torch.aten.scaled_dot_product_attention"(%100, %105, %110, %93, %48, %58, %57) : (!torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,1,128,128],f32>, !torch.float, !torch.bool, !torch.none) -> !torch.vtensor<[2,2,128,64],f32>
My environment:
- torch-mlir: 20240608.126
- transformers: 4.41.2
Have you encountered this before, or did I miss something?
Thanks.
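It looks like the TMTensor lowering only supports the default values for the optional arguments of scaled_dot_product_attention, and in your trace the op is called with an explicit attention-mask tensor (the fourth operand, !torch.vtensor<[2,1,128,128],f32>). I'm not sure whether there is a plan to support this: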
https://github.com/llvm/torch-mlir/blob/ca0e9066755b35c0889c6ab792265b0886325f50/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp#L1573-L1575
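For reference, here is a minimal pure-Python sketch of what scaled dot-product attention computes when a float attention mask is supplied, following the formula in the PyTorch docs for torch.nn.functional.scaled_dot_product_attention: softmax(Q K^T * scale + mask) V, with scale defaulting to 1/sqrt(d). The added mask term is presumably the part a lowering has to handle beyond the default, mask-free case. This is only an illustration with hypothetical helper names (`sdpa_reference`, `softmax`), not torch-mlir code. It covers a single head on 2-D lists, while the real op works on [batch, heads, seq, dim] tensors and broadcasts the mask (here [2,1,128,128]) across heads.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row max for numerical stability; -inf entries map to 0.
    # (A row that is entirely -inf is not handled and would yield nan.)
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q: Matrix, k: Matrix, v: Matrix,
                   attn_mask: Optional[Matrix] = None,
                   scale: Optional[float] = None) -> Matrix:
    """Single-head attention: softmax(q @ k^T * scale + attn_mask) @ v.

    Hypothetical illustration only; float-mask semantics as in the PyTorch docs.
    """
    d = len(q[0])
    s = scale if scale is not None else 1.0 / math.sqrt(d)
    out: Matrix = []
    for i, qi in enumerate(q):
        # Scaled dot product of query row i with every key row.
        scores = [s * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if attn_mask is not None:
            # A float mask is added to the scores; -inf blocks a key.
            scores = [x + m for x, m in zip(scores, attn_mask[i])]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```

For example, with q = [[1.0, 0.0]], k = [[1.0, 0.0], [0.0, 1.0]] and v = [[1.0, 2.0], [3.0, 4.0]], a mask row of [0.0, -inf] blocks the second key, so the result is exactly the first value row, [[1.0, 2.0]].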