tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

Batch normalization as a composite op #15449

Open jdh8 opened 1 week ago

jdh8 commented 1 week ago

Is your feature request related to a problem? Please describe. https://github.com/tenstorrent/pytorch2.0_ttnn/pull/415#discussion_r1857678373 shows that batch normalization should be a C++ composite op. When batch normalization is implemented in the compiler instead, it gets lowered to a series of elementwise ops plus ttnn.reshape calls that unsqueeze the per-channel vectors into tensors so their dimensions match the input.
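
For concreteness, here is roughly what that compiler-side lowering computes for inference-mode batch normalization. This is a minimal PyTorch sketch, assuming an NCHW input and per-channel parameter vectors; the function name is illustrative, not an existing API:

```python
import torch

def batch_norm_inference_decomposed(x, weight, bias, running_mean, running_var, eps=1e-5):
    # Unsqueeze the per-channel (C,) vectors to (1, C, 1, 1) so they broadcast
    # against the NCHW input; this is the role ttnn.reshape plays in the lowering.
    shape = (1, -1, 1, 1)
    mean = running_mean.reshape(shape)
    var = running_var.reshape(shape)
    w = weight.reshape(shape)
    b = bias.reshape(shape)
    # The remaining work is purely elementwise: subtract, rsqrt, multiply, add.
    return (x - mean) * torch.rsqrt(var + eps) * w + b
```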

Describe the solution you'd like Implement batch normalization as a multiary (ternary?) composite op alongside the other elementwise composites in https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/eltwise

Describe alternatives you've considered

  1. pytorch2.0_ttnn#415 already provides a suboptimal alternative, i.e. an implementation in the compiler.
  2. I am not sure about the arity of batch normalization (see the signature example after this list).
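
Regarding arity: PyTorch's functional form takes one input tensor plus four per-channel parameter tensors and scalar hyperparameters, so a faithful composite op would be more than ternary. A small usage example for reference:

```python
import torch
import torch.nn.functional as F

# Inference-mode call: five tensors (input, running_mean, running_var, weight, bias)
# plus the eps scalar -- effectively a 5-ary op rather than a ternary one.
x = torch.randn(2, 8, 4, 4)
mean, var = torch.zeros(8), torch.ones(8)
weight, bias = torch.ones(8), torch.zeros(8)
y = F.batch_norm(x, mean, var, weight=weight, bias=bias, training=False, eps=1e-5)
```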

Additional context If we also want to implement training-mode batch normalization, please take a look at PyTorch's decomposition algorithm.
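
A rough sketch of what that training-mode decomposition looks like in plain PyTorch terms, assuming NCHW input; the function name is illustrative, and the real PyTorch decomposition additionally returns the saved mean and inverse standard deviation for the backward pass:

```python
import torch

def batch_norm_training_decomposed(x, weight, bias, running_mean, running_var,
                                    momentum=0.1, eps=1e-5):
    # Batch statistics are reduced over every dimension except the channel dim.
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    # Normalize with the batch statistics, then apply the affine transform.
    y = (x - mean) * torch.rsqrt(var + eps)
    y = y * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
    # Running statistics are updated with the unbiased batch variance.
    n = x.numel() // x.shape[1]
    unbiased_var = var.flatten() * n / (n - 1)
    new_running_mean = (1 - momentum) * running_mean + momentum * mean.flatten()
    new_running_var = (1 - momentum) * running_var + momentum * unbiased_var
    return y, new_running_mean, new_running_var
```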

eyonland commented 1 day ago

I don't believe we want to do this. At the moment @VirdhatchaniKN is working on #12253. Is there something I'm missing here?