Is your feature request related to a problem? Please describe.
https://github.com/tenstorrent/pytorch2.0_ttnn/pull/415#discussion_r1857678373 shows that batch normalization should be a C++ composite op. When we implement batch normalization in the compiler, the compiler lowers it to a series of unary elementwise ops plus ttnn.reshape calls that unsqueeze per-channel vectors into tensors so that their dimensions match the input.
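To make the problem concrete, here is a minimal pure-Python sketch of the decomposed pattern, assuming inference-mode statistics and an NCHW input with H and W flattened into one spatial axis S. Every step is a separate elementwise op that materializes a full intermediate tensor, and each per-channel vector has to be unsqueezed and broadcast to [N][C][S] first. The names unsqueeze_channel, eltwise_binary and batch_norm_decomposed are hypothetical illustrations, not ttnn APIs.

```python
import math
from typing import Callable, List

Tensor3 = List[List[List[float]]]  # shape [N][C][S], S = H*W flattened (assumption)


def unsqueeze_channel(v: List[float], n: int, s: int) -> Tensor3:
    # Stand-in for the reshape + broadcast step: expand a [C] vector to [N][C][S].
    return [[[vc] * s for vc in v] for _ in range(n)]


def eltwise_binary(a: Tensor3, b: Tensor3, op: Callable[[float, float], float]) -> Tensor3:
    # One elementwise op over same-shaped tensors; returns a new intermediate tensor.
    return [[[op(p, q) for p, q in zip(ra, rb)] for ra, rb in zip(ca, cb)]
            for ca, cb in zip(a, b)]


def batch_norm_decomposed(x: Tensor3, mean: List[float], var: List[float],
                          weight: List[float], bias: List[float],
                          eps: float = 1e-5) -> Tensor3:
    n, s = len(x), len(x[0][0])
    inv_std = [1.0 / math.sqrt(v + eps) for v in var]  # per-channel rsqrt(var + eps)
    # Four separate passes over the data, each needing a broadcast copy of a vector.
    centered = eltwise_binary(x, unsqueeze_channel(mean, n, s), lambda p, q: p - q)
    normed = eltwise_binary(centered, unsqueeze_channel(inv_std, n, s), lambda p, q: p * q)
    scaled = eltwise_binary(normed, unsqueeze_channel(weight, n, s), lambda p, q: p * q)
    return eltwise_binary(scaled, unsqueeze_channel(bias, n, s), lambda p, q: p + q)
```

For example, batch_norm_decomposed([[[1.0, 3.0], [10.0, 20.0]]], [2.0, 15.0], [1.0, 25.0], [1.0, 2.0], [0.0, 1.0], eps=0.0) normalizes both channels; the cost is the four intermediate tensors and the broadcast copies, which a composite op could avoid.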
Describe the solution you'd like
Implement batch normalization as a multiary (possibly ternary) composite op. For reference, the existing elementwise ops are in https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/eltwise
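As a rough illustration of what a composite op could do differently, the sketch below (same assumed [N][C][S] layout, inference mode, hypothetical name batch_norm_fused, not a ttnn API) folds the per-channel constants once into scale = weight / sqrt(var + eps) and shift = bias - mean * scale, then makes a single pass over the input, indexing the constants by channel instead of materializing reshaped copies. Folding is one possible design choice; it is mathematically equivalent to the decomposed form but can differ in the last bits of floating-point rounding.

```python
import math
from typing import List, Optional

Tensor3 = List[List[List[float]]]  # shape [N][C][S], S = H*W flattened (assumption)


def batch_norm_fused(x: Tensor3, mean: List[float], var: List[float],
                     weight: Optional[List[float]] = None,
                     bias: Optional[List[float]] = None,
                     eps: float = 1e-5) -> Tensor3:
    c = len(mean)
    # Optional affine parameters default to the identity transform.
    weight = weight if weight is not None else [1.0] * c
    bias = bias if bias is not None else [0.0] * c
    # Fold the per-channel constants once: y = x * scale[c] + shift[c].
    scale = [w / math.sqrt(v + eps) for w, v in zip(weight, var)]
    shift = [b - m * sc for b, m, sc in zip(bias, mean, scale)]
    # Single pass over the input; the channel index selects the constants,
    # so no [N][C][S] copies of mean/var/weight/bias are created.
    return [[[xi * scale[ci] + shift[ci] for xi in row] for ci, row in enumerate(sample)]
            for sample in x]
```

Called with the same arguments as the decomposed version, it returns the same values up to rounding, with one output tensor and no intermediates.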
Describe alternatives you've considered
Additional context
If we also want to implement batch normalization for training, please take a look at PyTorch's decomposition algorithm.
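A rough sketch of what training mode would add, modeled on our reading of PyTorch's batch-norm decomposition: batch statistics are computed over every dim except C, the biased variance is used to normalize, and the running statistics are updated with momentum using the unbiased variance. It uses the same assumed [N][C][S] layout; batch_norm_train is a hypothetical name, not a ttnn or PyTorch API.

```python
import math
from typing import List, Tuple

Tensor3 = List[List[List[float]]]  # shape [N][C][S], S = H*W flattened (assumption)


def batch_norm_train(x: Tensor3, running_mean: List[float], running_var: List[float],
                     weight: List[float], bias: List[float],
                     momentum: float = 0.1, eps: float = 1e-5
                     ) -> Tuple[Tensor3, List[float], List[float]]:
    n, c, s = len(x), len(running_mean), len(x[0][0])
    count = n * s  # elements reduced per channel
    out = [[[0.0] * s for _ in range(c)] for _ in range(n)]
    new_mean: List[float] = []
    new_var: List[float] = []
    for ci in range(c):
        vals = [v for sample in x for v in sample[ci]]
        mean = sum(vals) / count
        biased_var = sum((v - mean) ** 2 for v in vals) / count  # used to normalize
        rstd = 1.0 / math.sqrt(biased_var + eps)
        for ni, sample in enumerate(x):
            out[ni][ci] = [(v - mean) * rstd * weight[ci] + bias[ci] for v in sample[ci]]
        # Running statistics use the unbiased variance (requires count > 1).
        unbiased_var = biased_var * count / (count - 1)
        new_mean.append(momentum * mean + (1 - momentum) * running_mean[ci])
        new_var.append(momentum * unbiased_var + (1 - momentum) * running_var[ci])
    return out, new_mean, new_var
```

The function returns the normalized output together with the updated running statistics, so a caller can feed the latter back into inference-mode batch normalization.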