vinayakdsci closed this 1 month ago
The error comes from onnx.ScatterElements: the torch-to-linalg lowering is missing. See https://github.com/nod-ai/SHARK-Turbine/issues/812
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg model.mlir --debug
%58 = torch.aten.scatter.reduce %arg4, %int0, %57, %48, %str : !torch.vtensor<[2708],f32>, !torch.int, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.str -> !torch.vtensor<[2708],f32>
PyTorch C++ implementation: https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp#L858
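For anyone picking up the lowering, here is a rough reference for what the op computes in the 1-D, `dim = 0` case from the IR above. This is only a sketch in plain Python, not the PyTorch kernel or the torch-mlir lowering. The helper name `scatter_reduce_1d` is hypothetical, and it assumes the reduce string is `"add"` or `"multiply"` (the value of `%str` is not shown in the IR) and that indices are non-negative and in range.

```python
def scatter_reduce_1d(self_vals, index, src, reduce):
    """Reference semantics for a 1-D scatter with reduction along dim 0.

    Hypothetical helper for illustration only. Assumes reduce is "add" or
    "multiply" and that every index is in [0, len(self_vals)).
    """
    ops = {
        "add": lambda a, b: a + b,
        "multiply": lambda a, b: a * b,
    }
    if reduce not in ops:
        raise ValueError(f"unsupported reduce mode: {reduce!r}")
    if len(index) > len(src):
        raise ValueError("index must not be longer than src")

    # Start from a copy of the input, so untouched positions keep their value.
    out = list(self_vals)
    for i, idx in enumerate(index):
        if not 0 <= idx < len(out):
            raise IndexError(f"index {idx} out of range at position {i}")
        # Repeated indices accumulate into the same output slot.
        out[idx] = ops[reduce](out[idx], src[i])
    return out


if __name__ == "__main__":
    print(scatter_reduce_1d([1.0, 2.0, 3.0], [0, 0, 2], [10.0, 20.0, 30.0], "add"))
```

The point to carry into a linalg lowering is the accumulation on duplicate indices: two updates that target the same slot must both be applied, so the update loop cannot be treated as independent per-element writes.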
@AmosLewis One of the turbine camp folk (JeongHyun Leon Kim) wanted to pick this up. Are you currently working on this?
The `torch.constant.int` operation fails during lowering with
failed to legalize operation 'torch.constant.int'
as reported in https://github.com/nod-ai/SHARK-Turbine/issues/812.
IR to reproduce the issue:
commands: