nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

LSTM #315

Open renxida opened 8 months ago

renxida commented 8 months ago

Looks like there isn't a corresponding torch lstm op. Looking into how to implement it.

vivekkhandelwal1 commented 8 months ago

Hi @renxida, there's a corresponding LSTM op in PyTorch, see: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html. You can add OnnxToTorch and TorchToLinalg lowering for LSTM op. This would be a good learning experience as well.

stellaraccident commented 8 months ago

You'll likely want to decompose it versus converting to linalg directly. There is rarely value to a compiler to keep lstm as a named op. It is a trivial fusion.
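For concreteness, one step of that decomposition might look like the following PyTorch sketch (`lstm_cell_step` is a hypothetical helper name, not the actual torch-mlir lowering): a single LSTM time step expressed purely in matmuls and elementwise ops, which reproduces `torch.nn.LSTMCell` and is exactly the kind of fusion a compiler handles well on its own.

```python
import torch

def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    # One LSTM time step built only from primitive ops (matmul, add,
    # sigmoid, tanh, elementwise mul), so the compiler sees fusible
    # primitives instead of an opaque named lstm op.
    # Gate order follows PyTorch's convention: i, f, g, o.
    gates = x @ w_ih.T + h @ w_hh.T + b_ih + b_hh
    i, f, g, o = gates.chunk(4, dim=-1)
    c_next = f.sigmoid() * c + i.sigmoid() * g.tanh()
    h_next = o.sigmoid() * c_next.tanh()
    return h_next, c_next
```

Checking this against `torch.nn.LSTMCell` with the cell's own weights should give matching outputs up to floating-point tolerance.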

vivekkhandelwal1 commented 8 months ago

> You'll likely want to decompose it versus converting to linalg directly. There is rarely value to a compiler to keep lstm as a named op. It is a trivial fusion.

Yeah, that's a better approach. This might be of some help. https://github.com/pytorch/pytorch/blob/a97d00cca5c1f47e74048f110df5706669a84e6e/torch/_decomp/decompositions.py#L3135

vivekkhandelwal1 commented 8 months ago

Btw @stellaraccident, importing ONNX models has a drawback here: we can't make use of the existing PyTorch decompositions.

stellaraccident commented 8 months ago

Yeah, I know. There's no other way to do it, though: ONNX is a C++-only technology.

vivekkhandelwal1 commented 7 months ago

Hi @renxida, are you still working on this op?

renxida commented 7 months ago

@vivekkhandelwal1 nope, not currently. Feel free to grab it.

Right now I'm adding ONNX end-to-end test cases in @kumardeepakamd's SHARK-TestSuite repo.

kumardeepakamd commented 6 months ago

@vivekkhandelwal1 and @renxida, has this landed, and can it be marked completed? Have e2eshark onnx and/or pytorch tests and torch-mlir lit and torch op tests been added?

renxida commented 5 months ago

This will be fixed by https://github.com/llvm/torch-mlir/pull/2969

AmosLewis commented 1 month ago

```shell
python ./run.py --tolerance 0.001 0.001 --cachedir /proj/gdba/shark/cache -f onnx -g models --mode onnx --report --tests onnx/models/sequencer2d_m_vaiq
```

```
sequencer2d_m_vaiq.default.onnx.torch.mlir:426:14: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %422:3 = torch.operator "onnx.LSTM"(%403, %181, %182, %180, %none, %412, %421) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>, !torch.none, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
             ^
sequencer2d_m_vaiq.default.onnx.torch.mlir:426:14: note: see current operation: %940:3 = "torch.operator"(%837, %363, %365, %361, %792, %888, %939) <{name = "onnx.LSTM"}> {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>, !torch.none, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
```

@renxida this needs another look.

renxida commented 1 month ago

Aw crap, looks like I have to support the bidirectional onnx.LSTM case for this model to work.
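What bidirectional support entails can be sketched in PyTorch (hypothetical helper names, not the torch-mlir implementation, which operates on MLIR): run a decomposed LSTM step forward over the sequence and a second set of weights backward over the reversed sequence, then stack the per-direction outputs on a `num_directions` axis. That yields ONNX's `Y` layout `[seq_length, num_directions, batch, hidden]`, i.e. the `[32,2,32,48]` result type in the error above. Note also that ONNX packs its gate weights in a different order (i, o, f, c) than PyTorch's i, f, g, o, which the real lowering has to account for.

```python
import torch

def lstm_step(x_t, h, c, w_ih, w_hh, b_ih, b_hh):
    # Decomposed single-direction LSTM step (PyTorch gate order: i, f, g, o).
    gates = x_t @ w_ih.T + h @ w_hh.T + b_ih + b_hh
    i, f, g, o = gates.chunk(4, dim=-1)
    c_next = f.sigmoid() * c + i.sigmoid() * g.tanh()
    return o.sigmoid() * c_next.tanh(), c_next

def bidirectional_lstm(x, fwd_w, bwd_w, h0, c0):
    # x: [seq_len, batch, input]; h0/c0: [num_directions=2, batch, hidden].
    # fwd_w / bwd_w: (w_ih, w_hh, b_ih, b_hh) tuples, one per direction.
    # Returns Y in ONNX layout [seq_len, 2, batch, hidden] plus final h/c.
    seq_len = x.shape[0]
    h_f, c_f = h0[0], c0[0]
    h_b, c_b = h0[1], c0[1]
    ys_f, ys_b = [], []
    for t in range(seq_len):            # forward direction: t = 0 .. T-1
        h_f, c_f = lstm_step(x[t], h_f, c_f, *fwd_w)
        ys_f.append(h_f)
    for t in reversed(range(seq_len)):  # reverse direction: t = T-1 .. 0
        h_b, c_b = lstm_step(x[t], h_b, c_b, *bwd_w)
        ys_b.append(h_b)
    ys_b.reverse()                      # re-align reverse outputs with time order
    y = torch.stack([torch.stack(ys_f), torch.stack(ys_b)], dim=1)
    return y, torch.stack([h_f, h_b]), torch.stack([c_f, c_b])
```

Reshaping the direction axis into the feature axis recovers `torch.nn.LSTM(bidirectional=True)`'s `[seq, batch, 2*hidden]` output, which makes the sketch easy to sanity-check against PyTorch.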