Open dvlp-r opened 1 year ago
This is likely an error due to torch-mlir not having a way to represent the SparseTensor type. I think setting use_tracing=True in torch_mlir.compile might fix it, since the isinstance(...) check would get evaluated at trace time and its result inserted as a constant in the graph.
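To illustrate why tracing makes an isinstance(...) check disappear, here is a minimal sketch in plain Python. It is a toy tracer, not torch.jit.trace or torch-mlir code, and the names SparseTensor, TracedValue, model and trace are hypothetical stand-ins invented for this example. It only shows the general idea: a Python-level branch is resolved while tracing, so only the ops on the taken path are recorded.

```python
# Toy illustration only: this is NOT torch.jit.trace or torch-mlir code.

class SparseTensor:
    """Hypothetical stand-in for a sparse tensor type."""


class TracedValue:
    """Records every arithmetic op applied to it into a shared graph list."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _record(self, op, other):
        out = TracedValue(f"v{len(self.graph)}", self.graph)
        self.graph.append((op, self.name, other, out.name))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def model(x, edge_index):
    # Plain Python check: the tracer never sees it, only its outcome.
    if isinstance(edge_index, SparseTensor):
        return x * 2
    return x + 1


def trace(fn, *extra_args):
    """Run fn once on a traced input and return the recorded ops."""
    graph = []
    fn(TracedValue("x", graph), *extra_args)
    return graph


dense_graph = trace(model, object())
sparse_graph = trace(model, SparseTensor())
print(dense_graph)   # [('add', 'x', 1, 'v0')]
print(sparse_graph)  # [('mul', 'x', 2, 'v0')]
```

Neither recorded graph contains the isinstance check itself. The flip side is that a trace is only valid for inputs that take the same branch as the example input, which is what the TracerWarning messages are about.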
Unfortunately, use_tracing did not help:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpgqejm00y.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpgqejm00y.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpefq5y_le.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpefq5y_le.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpecbnml5b.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpecbnml5b.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpxialucq8.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpxialucq8.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmp5yfadsny.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmp5yfadsny.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
Iteration: 0%| | 0/4113 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/Users/dvlpr/torch-mlir/examples/gnn/gin.py", line 110, in <module>
module = torch_mlir.compile(gin, (x, edge_index, edge_attr, batch_f), output_type="linalg-on-tensors", use_tracing=True)
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 444, in compile
return _lower_mlir_module(verbose, output_type, mb.module)
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 286, in _lower_mlir_module
run_pipeline_with_repro_report(
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 75, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: "/Users/dvlpr/mlir_venv/lib/python3.10/site-packages/torch_geometric/utils/scatter.py":78:0: failed to legalize operation 'torch.constant.int'
note: "/Users/dvlpr/mlir_venv/lib/python3.10/site-packages/torch_geometric/utils/scatter.py":78:0: see current operation: %58 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' /var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/GNN.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
The error is different now. What's likely happening is that the model uses a different op that is not yet supported in torch-mlir. If you have a link to the temporary file /var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/GNN.mlir generated during the error, I can take a look and see what is missing.
Hi @ramiro050, here is the link to a gist with the contents of the GNN.mlir file. Thanks for your help.
The issue is that there is currently no TorchToLinalg pattern for the AtenScatterAdd op.
Hi @ramiro050, thanks for your help. Following your suggestion, I tried to remove the scatter dependency by using a different graph pooling. Now the error below appears: it is the same as before, but it points to a class in the OGB (Open Graph Benchmark) package.
I looked at that class expecting it to use scatter, but it does not, so I find this a bit strange given that the error looks the same. The GNN.mlir still makes use of scatter_add, so the problem may still be that one, but I do not understand why the error log points to the mol_encoder.py file and no longer to scatter.py as before.
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmplrzcc_vc.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmplrzcc_vc.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpp61rshb5.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpp61rshb5.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpqnbi7y_9.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpqnbi7y_9.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpik76df4v.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpik76df4v.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmp0kr5secw.py:63: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if edge_index.size(0) != 2:
/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmp0kr5secw.py:116: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
init = torch.tensor(0.)
Iteration: 0%| | 0/4113 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/dvlpr/torch-mlir/examples/gnn/gin.py", line 110, in <module>
module = torch_mlir.compile(gin, (x, edge_index, edge_attr, batch_f), output_type="linalg-on-tensors", use_tracing=True)
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 444, in compile
return _lower_mlir_module(verbose, output_type, mb.module)
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 286, in _lower_mlir_module
run_pipeline_with_repro_report(
File "/Users/dvlpr/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 75, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: "/Users/dvlpr/mlir_venv/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py":21:0: failed to legalize operation 'torch.constant.int'
note: "/Users/dvlpr/mlir_venv/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py":21:0: see current operation: %64 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' /var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/GNN.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Here is the link to the new GNN.mlir file in case you have time to look at it. Thanks.
The issue is the same one: scatter_add is being used and there is no support for it. The error message is a bit confusing. What's happening is that aten.scatter_add is not getting converted to linalg-on-tensors, which causes the arguments to the op to also stay in the graph. Because the arguments come first in the graph, torch-mlir fails on the torch.constant.int 0 rather than on the scatter_add op, and the constant int has its source code location in mol_encoder.py. However, the scatter_add op is from the scatter.py file.
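The reasoning above can be sketched as a small toy model in plain Python. This is only an illustration of the explanation in this comment, not MLIR's actual dialect-conversion algorithm. The graph layout, the helper first_illegal_op and the pattern sets are all assumptions made for the example.

```python
# Toy model of the explanation above; NOT MLIR's real conversion logic.
# Each op is (result name, op name, operand names), listed in program order.
graph = [
    ("%0", "torch.constant.int", []),
    ("%1", "torch.aten.scatter_add", ["%self", "%0", "%index", "%src"]),
]


def first_illegal_op(graph, has_pattern):
    """Return the first op (in program order) left in the torch dialect."""
    # Ops without a conversion pattern cannot be lowered.
    stuck = {res for res, name, _ in graph if name not in has_pattern}
    # Operands of a stuck op have to stay around for that op to use.
    kept_alive = {
        operand
        for res, _, operands in graph
        if res in stuck
        for operand in operands
    }
    for res, name, _ in graph:
        if res in stuck or res in kept_alive:
            return res, name
    return None


# Assumed pattern sets: the constant itself is convertible, scatter_add is not.
without_scatter = {"torch.constant.int"}
with_scatter = {"torch.constant.int", "torch.aten.scatter_add"}

reported = first_illegal_op(graph, without_scatter)
fixed = first_illegal_op(graph, with_scatter)
print(reported)  # ('%0', 'torch.constant.int')
print(fixed)     # None
```

In this toy model the constant is reported even though the missing pattern belongs to scatter_add, and the report goes away once scatter_add has a pattern. When reading a real failure, it can therefore help to look at the users of the reported op in the dumped IR.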
Hope this helps!
Thank you so much @ramiro050 for your thorough explanation. Do you have any suggestions on where to start in order to implement support for aten.scatter_add?
For scatter ops, we need to lower to the TMTensor dialect. Here's an example of adding a similar op: https://github.com/llvm/torch-mlir/commit/552887783a58376842d3b2ca64f97f8dcd84a347
Let me know if you have any questions.
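Whatever form the lowering takes, it has to reproduce the accumulate semantics of scatter_add, including repeated indices. As a reference point for testing, here is a minimal sketch of those semantics along dim 0 for 2-D inputs, written with plain nested lists rather than torch or TMTensor. The helper name scatter_add_dim0 is hypothetical.

```python
def scatter_add_dim0(self_rows, index, src):
    """Reference semantics of scatter_add along dim 0 for 2-D nested lists:

        out[index[i][j]][j] += src[i][j]

    The loop runs over the shape of index, and repeated target rows
    accumulate instead of overwriting each other.
    """
    out = [row[:] for row in self_rows]  # leave the input untouched
    for i, index_row in enumerate(index):
        for j, target_row in enumerate(index_row):
            out[target_row][j] += src[i][j]
    return out


result = scatter_add_dim0(
    [[0, 0], [0, 0], [0, 0]],   # self
    [[0, 1], [0, 2], [2, 1]],   # index
    [[1, 2], [3, 4], [5, 6]],   # src
)
print(result)  # [[4, 0], [0, 8], [5, 4]]
```

Row 0, column 0 receives both 1 and 3 because index names row 0 twice in that column. A correct lowering has to preserve that accumulation.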
Hi, when trying to use torch_mlir.compile with a model that uses sparse tensors, the following error appears.
I know torch_mlir does not support sparse tensors (they are lowered as dense), but this error prevents the lowering altogether. Does anyone know how to solve it? Thanks