tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

StableHLO MNIST Softmax test is failing due to Reduction Op #805

Open uazizTT opened 1 month ago

uazizTT commented 1 month ago

pytest -svv tests/TTIR/test_mnist.py::test_softmax

fails with error:

Always | FATAL | keepdim=False is not supported

The keep_dim flag for the reduction op is always set to false. A fix might require setting it to true and following the reduction with a reshape op.

AleksKnezevic commented 1 month ago

@mrakitaTT is looking into this.

AleksKnezevic commented 1 month ago

Please enable the tt-xla test once this has been resolved.

mrakitaTT commented 3 weeks ago

I made a fix for this last week to get softmax working (on the softmax branch), but after yesterday's discussion I am not sure whether we want to check it in or pursue this on the Metal side.

Basically, ttnn doesn't support calling reduce ops with keepdim=False, so the workaround is to call them with keepdim=True and then manually reshape the returned tensor to remove the reduced dimensions.
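
To make the shape bookkeeping of this workaround concrete, here is a minimal pure-Python sketch. The helpers `reduce_shape` and `workaround_shape` are hypothetical illustrations, not ttnn or tt-mlir APIs, and they only track shapes, not tensor data.

```python
# Hypothetical shape-only sketch (not a ttnn/tt-mlir API): shows that
# reduce(keepdim=True) followed by a reshape gives the keepdim=False shape.

def reduce_shape(shape, dims, keepdim):
    """Return the result shape of reducing `shape` over `dims`."""
    rank = len(shape)
    # Normalize negative dims (e.g. -1 means the last dimension).
    norm = {d % rank for d in dims}
    if keepdim:
        # Reduced dimensions stay in the result with size 1.
        return [1 if i in norm else s for i, s in enumerate(shape)]
    # Reduced dimensions are dropped from the result entirely.
    return [s for i, s in enumerate(shape) if i not in norm]


def workaround_shape(shape, dims):
    """Emulate keepdim=False: reduce with keepdim=True, then reshape away the reduced dims."""
    kept = reduce_shape(shape, dims, keepdim=True)
    rank = len(shape)
    norm = {d % rank for d in dims}
    # The reshape target drops exactly the reduced axes (by index), all of size 1.
    return [s for i, s in enumerate(kept) if i not in norm]


# Softmax-style case: reduce over the last dimension of a [batch, classes] tensor.
print(reduce_shape([32, 10], [-1], keepdim=True))   # [32, 1]
print(reduce_shape([32, 10], [-1], keepdim=False))  # [32]
print(workaround_shape([32, 10], [-1]))             # [32]
```

Dropping the reduced axes by index, rather than removing every size-1 axis, matters when the input already has size-1 dimensions that are not being reduced. Since only size-1 axes are removed, the reshape does not reorder any elements.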

I've opened an issue for this to Metal team: https://github.com/tenstorrent/tt-metal/issues/13361

@tt-mpantic could you please help us escalate this?

AleksKnezevic commented 3 weeks ago

@LPanosTT, fyi.

mrakitaTT commented 1 week ago

Moving this to the External ttnn/tt-metal component so it can be escalated to the Metal team. Link to the Metal issue: https://github.com/tenstorrent/tt-metal/issues/13361

@tt-mpantic could you please help us escalate this? (I also sent a DM on Slack.)

mrakitaTT commented 1 week ago

@nsmithtt @sdjordjevicTT While this is being escalated to the Metal team, can we discuss a potential workaround in the meantime? It doesn't seem likely to me that the Metal team will fix this soon, this issue has been open on our side for too long, and other people keep stumbling on it and fixing it themselves (example: https://github.com/tenstorrent/tt-mlir/pull/899).

Here is the proposed workaround decomposition (copying comment from my commit):

// StableHLO.ReduceOp always removes reduce dimensions from the result shape
// so ideally we would just convert it to TTIR.ReduceOp with keepDim=False.
// Unfortunately we cannot do this because Metal TTNN implementation of
// Reduce doesn't yet support keepDim=False. As a workaround, we convert it
// to a combination of TTIR.ReduceOp with keepDim=True + TTIR.ReshapeOp to
// remove the reduce dims so that the rest of the graph is not affected.
// When this is not needed (because the type converter has already promoted
// the rank of the op result), we avoid adding an unnecessary Reshape op.
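
The decision logic in the comment above (always reduce with keepDim=True, add a Reshape only when the expected result shape still differs) can be sketched as follows. This is a hypothetical illustration: ops are modelled as plain `(name, shape)` tuples, and `decompose_reduce` is not a real tt-mlir function.

```python
# Hypothetical sketch of the Reduce(keepDim=True) + optional Reshape decomposition.
# Ops are plain (name, result_shape) tuples; none of these names are tt-mlir APIs.

def decompose_reduce(input_shape, dims, expected_result_shape):
    """Lower a reduce that drops reduced dims into a list of (op_name, shape) tuples."""
    rank = len(input_shape)
    norm = {d % rank for d in dims}
    # Step 1: always reduce with keep_dim=True, since keep_dim=False is unsupported.
    kept_shape = [1 if i in norm else s for i, s in enumerate(input_shape)]
    ops = [("reduce(keep_dim=True)", kept_shape)]
    # Step 2: add a reshape only if the expected result shape differs; if the
    # result type already matches the keep_dim=True shape, the reshape is skipped.
    if list(expected_result_shape) != kept_shape:
        ops.append(("reshape", list(expected_result_shape)))
    return ops


print(decompose_reduce([32, 10], [1], [32]))
# [('reduce(keep_dim=True)', [32, 1]), ('reshape', [32])]
print(decompose_reduce([32, 10], [1], [32, 1]))
# [('reduce(keep_dim=True)', [32, 1])]
```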

nsmithtt commented 1 week ago

@mrakitaTT, the workaround sounds good to me, but let's implement it in the TTNN dialect passes, not in TTIR. @sdjordjevicTT has a document that outlines our plan for workarounds in TTNN; I tagged you in a comment on the doc.