dgolubovicTT opened this issue 5 days ago
Yeah, I agree with you @dgolubovicTT. We will take a look at it, thanks for reporting! :)
I've made a TVM pattern callback that inserts a cast before and after each max_pool2d, and it solved the issue. This was just to test it locally. I still believe that Forge (TVM included) shouldn't have to be aware of TTNN constraints. There are more operations in TTNN that require bfloat16 inputs, or even bfloat16 weights; for example, the embedding op requires the embedding weights to be bfloat16.
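Roughly, the callback I mean looks like the minimal sketch below, written against TVM's Relay dataflow pattern API. This is not the exact code from my local test; the class name, the attribute forwarding, and casting back to float32 are assumptions, but it shows the shape of the workaround.

```python
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard


class CastAroundMaxPool2d(DFPatternCallback):
    """Wrap every nn.max_pool2d in a bfloat16 cast (and cast back afterwards)."""

    def __init__(self):
        # rewrite_once avoids re-matching the max_pool2d we re-create below.
        super().__init__(rewrite_once=True)
        self.act = wildcard()
        self.pattern = is_op("nn.max_pool2d")(self.act)

    def callback(self, pre, post, node_map):
        act = node_map[self.act][0]
        pool = node_map[self.pattern][0]
        attrs = pool.attrs

        # Cast the activation to bfloat16, redo the pooling with the same
        # attributes, then cast the result back (hardcoded to float32 here;
        # in practice you'd restore the original dtype).
        casted = relay.cast(act, "bfloat16")
        new_pool = relay.nn.max_pool2d(
            casted,
            pool_size=attrs.pool_size,
            strides=attrs.strides,
            dilation=attrs.dilation,
            padding=attrs.padding,
            layout=attrs.layout,
            ceil_mode=attrs.ceil_mode,
        )
        return relay.cast(new_pool, "float32")


# Usage: mod["main"] = rewrite(CastAroundMaxPool2d(), mod["main"])
```

The rewrite_once flag matters because the rewritten graph still contains a max_pool2d (now fed by a cast), and without it the rewriter would keep matching that op.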
So now a more general question arises: who should be responsible for satisfying these TTNN data format constraints, Forge or the layers below it?
@dgolubovicTT I agree with you, you shouldn't have to worry about the data format constraints on the Forge-FE side.
Some TTNN ops can automatically do the cast, but some ops fail if the inputs aren't in the specified data format. We have opened an issue on our side to handle such cases: https://github.com/tenstorrent/tt-mlir/issues/1433
For each workaround that we introduce in the tt-mlir stack, we are filing an issue on the metal side to track the resolution. Once the issue is resolved on the metal side, we will remove the workarounds from the compiler code.
This sounds promising. Thanks!
When I run the ResNet test from Forge I get an unexpected error:
loc("max_pool2d_17"("forward":4294967295:3591)): error: 'ttnn.max_pool2d' op ttnn.max_pool2d currently only supports an input type of bfloat16. Recieved 'f32'.
Turns out it is due to an assert in mlir::tt::ttnn::Conv2dOp::verify(), added in a PR. So if ttnn.max_pool2d only supports bfloat16, we shouldn't just fail the compile when its input is float32. We should probably add a cast op to handle this and move on with the compile.
@sdjordjevicTT can we prioritize this? It is a blocker for ResNet bringup. FYI @nvukobratTT