tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Handle max_pool2d input dataformat float32 #1389

Open dgolubovicTT opened 5 days ago

dgolubovicTT commented 5 days ago

When I run the resnet test from forge, I get an unexpected error:

loc("max_pool2d_17"("forward":4294967295:3591)): error: 'ttnn.max_pool2d' op ttnn.max_pool2d currently only supports an input type of bfloat16. Recieved 'f32'.

Turns out it is due to an assert added to mlir::tt::ttnn::Conv2dOp::verify() in a PR.

So if ttnn.max_pool2d only supports bfloat16, we shouldn't just fail compilation when its input is float32. We should probably insert a cast op to handle this and continue compiling.
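A minimal sketch of the proposed fix, assuming a toy graph representation (the op and dtype names here are illustrative, not the actual tt-mlir pass API): instead of failing verification, a pass could walk the graph and wrap every bfloat16-only op that receives an f32 input with a cast before and after.

```python
# Hypothetical sketch, NOT tt-mlir code: a graph is modeled as a list of
# (op_name, input_dtype) pairs. Ops with a bfloat16-only input constraint
# get a cast inserted before and after so compilation can proceed.
BF16_ONLY_OPS = {"max_pool2d", "embedding"}  # illustrative constraint set

def insert_casts(graph):
    """Return a new op list with casts wrapped around bf16-only ops."""
    out = []
    for op, dtype in graph:
        if op in BF16_ONLY_OPS and dtype == "f32":
            out.append(("cast_to_bf16", "f32"))   # downcast the f32 input
            out.append((op, "bf16"))              # op now sees bfloat16
            out.append(("cast_to_f32", "bf16"))   # restore the original dtype
        else:
            out.append((op, dtype))
    return out
```

The upcast after the op keeps the rest of the graph seeing the dtype it expected, so the workaround stays local to the constrained op.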

@sdjordjevicTT can we prioritize this because it is a blocker for ResNet bringup? fyi @nvukobratTT

sdjordjevicTT commented 5 days ago

Yeah, I agree with you @dgolubovicTT. We will take a look at it, thanks for reporting! :)

dgolubovicTT commented 3 days ago

I've made a TVM pattern callback that inserts casts before and after each max_pool2d, and it solved the issue. This was just to test it locally. I still believe that forge (TVM included) shouldn't be aware of ttnn constraints. There are more ttnn operations that require bfloat16 inputs or even bfloat16 weights; for example, the embedding op requires its weights to be bfloat16.

So now the more general question arises:

  1. The default data format in torch is float32, and many model weights are stored in that format.
  2. tt-metal obviously works in bfloat16, but it lets inputs be float32 and implicitly casts them. However, some ttnn ops require that inputs or weights already be in bfloat16, so we would potentially have to insert back-and-forth casts throughout the graph. This is OK for a start, but we should push for ttnn to accept float32 and cast it implicitly, as it already does for other ops (e.g. ttnn.add).
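One concrete cost of the per-op workaround in point 2: when two bfloat16-only ops are adjacent, the naive insertion produces a cast back to f32 immediately followed by a cast back to bf16. A small peephole pass could fold these pairs (a hypothetical sketch with illustrative op names, not tt-mlir code):

```python
# Hypothetical sketch: drop a cast_to_f32 that is immediately undone by a
# cast_to_bf16. This fold is safe because the value was bfloat16 before the
# pair, so bf16 -> f32 -> bf16 is an identity round trip. The reverse order
# (bf16 cast first) is NOT folded, since removing it would change precision.
def fold_redundant_casts(ops):
    out = []
    for op in ops:
        if out and out[-1] == "cast_to_f32" and op == "cast_to_bf16":
            out.pop()  # the two casts cancel; emit neither
        else:
            out.append(op)
    return out
```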

sdjordjevicTT commented 3 days ago

@dgolubovicTT I agree with you, you shouldn't worry about the data format constraints from Forge-FE.

Some TTNN ops can do the cast automatically, but some ops fail if the inputs aren't in the expected data format. We have filed an issue on our side to handle such cases: https://github.com/tenstorrent/tt-mlir/issues/1433

For each workaround that we introduce in the tt-mlir stack, we are filing an issue on the metal side to track the resolution. Once the issue is resolved on the metal side, we will remove the workarounds from the compiler code.

dgolubovicTT commented 2 days ago

This sounds promising. Thanks!