uazizTT opened 1 month ago
@mrakitaTT is looking into this.
Please enable the tt-xla test once this has been resolved.
I made a fix for this last week to get softmax working (on the softmax branch), but after yesterday's discussion I'm not sure whether we want to check it in or pursue this on the Metal side.
Basically, ttnn doesn't support calling reduce ops with keepdim=False, so the workaround is to call them with keepdim=True and then manually reshape the returned tensor to remove the reduced dimensions.
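To illustrate the workaround (a sketch using NumPy, not the ttnn API): reducing with keepdims=True and then reshaping away the size-1 dimensions produces the same result as reducing with keepdims=False directly.

```python
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# Desired result: reduce with keepdim=False (what StableHLO.ReduceOp produces).
want = x.sum(axis=1, keepdims=False)   # shape (2, 4)

# Workaround: reduce with keepdim=True, then reshape away the reduced dim.
kept = x.sum(axis=1, keepdims=True)    # shape (2, 1, 4)
got = kept.reshape(2, 4)

assert got.shape == want.shape == (2, 4)
assert np.array_equal(got, want)
```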
I've opened an issue for this to Metal team: https://github.com/tenstorrent/tt-metal/issues/13361
@tt-mpantic could you please help us escalate this?
@LPanosTT, fyi.
Moving to the External ttnn/tt-metal component so this can be escalated to the Metal team. Link to the metal issue: https://github.com/tenstorrent/tt-metal/issues/13361
@tt-mpantic could you please help us escalate this? (sent dm on slack too)
@nsmithtt @sdjordjevicTT While this is being escalated to the Metal team, can we discuss the potential workaround in the meantime? It doesn't seem likely that the Metal team will fix this soon, this issue has been open for too long on our side, and other people keep stumbling upon it and fixing it themselves (example: https://github.com/tenstorrent/tt-mlir/pull/899).
Here is the proposed workaround decomposition (copying comment from my commit):
// StableHLO.ReduceOp always removes the reduce dimensions from the result
// shape, so ideally we would just convert it to TTIR.ReduceOp with
// keepDim=False. Unfortunately we cannot do this because the Metal TTNN
// implementation of Reduce doesn't yet support keepDim=False. As a
// workaround, we convert it to a combination of TTIR.ReduceOp with
// keepDim=True + TTIR.ReshapeOp to remove the reduce dims so that the rest
// of the graph is not affected. In cases where this is not needed (because
// the type converter has already promoted the rank of the op result), we
// avoid adding an unnecessary Reshape op.
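The shape arithmetic behind that decomposition can be sketched as follows (a hypothetical helper for illustration only; the real logic lives in the conversion pass):

```python
def decompose_reduce(input_shape, reduce_dims, result_rank):
    """Compute shapes for the keepDim=True + Reshape workaround.

    Returns (keepdim_shape, reshape_target), where reshape_target is None
    when no Reshape op is needed because the type converter already
    promoted the result to the kept rank.
    """
    # Reduce with keepDim=True: reduced dims become size 1.
    keepdim_shape = [1 if d in reduce_dims else s
                     for d, s in enumerate(input_shape)]
    # StableHLO semantics: reduced dims are removed entirely.
    squeezed = [s for d, s in enumerate(input_shape)
                if d not in reduce_dims]
    if result_rank == len(keepdim_shape):
        return keepdim_shape, None
    return keepdim_shape, squeezed

# Reducing dim 1 of a (2, 3, 4) tensor when the consumer expects rank 2:
print(decompose_reduce((2, 3, 4), {1}, 2))  # ([2, 1, 4], [2, 4])
# When the type converter already promoted the result to rank 3:
print(decompose_reduce((2, 3, 4), {1}, 3))  # ([2, 1, 4], None)
```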
@mrakitaTT, the workaround sounds good to me, but let's do the workaround inside of TTNN dialect passes, not TTIR. @sdjordjevicTT has a document that outlines our plan for workarounds in TTNN, I tagged you in a comment on the doc.
pytest -svv tests/TTIR/test_mnist.py::test_softmax
fails with error:
Always | FATAL | keepdim=False is not supported
The keep_dim flag for the reduction op is always set to false. Fixing this might require setting it to true, followed by a reshape op.