tenstorrent / tt-metal

TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

Support `aten.max_pool2d_with_indices` #12099

Open jdh8 opened 3 weeks ago

jdh8 commented 3 weeks ago

As we found out in tenstorrent/pytorch2.0_ttnn#114, PyTorch converts `torch.nn.functional.max_pool2d` to `aten.max_pool2d_with_indices` even if indices are not requested. For the sake of completeness and a simple structure, I suggest making a version of `max_pool2d` that simultaneously outputs indices.
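For reference, the semantics of the requested op can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the ttnn or PyTorch implementation: it handles a single 2D plane with no padding or dilation, and `max_pool2d_with_indices` / `max_pool2d` here are hypothetical helpers. Following PyTorch's convention, each index is the position of the selected element in the flattened input plane (`row * width + col`).

```python
from typing import List, Optional, Tuple

Plane = List[List[float]]


def max_pool2d_with_indices(
    x: Plane, kernel_size: int, stride: Optional[int] = None
) -> Tuple[Plane, List[List[int]]]:
    """Max-pool one 2D plane (no padding, no dilation).

    Returns (values, indices); each index is row * width + col of the
    chosen element in the flattened input plane.
    """
    if stride is None:
        stride = kernel_size  # same default as torch.nn.functional.max_pool2d
    height, width = len(x), len(x[0])
    out_h = (height - kernel_size) // stride + 1
    out_w = (width - kernel_size) // stride + 1
    values: Plane = []
    indices: List[List[int]] = []
    for oh in range(out_h):
        value_row, index_row = [], []
        for ow in range(out_w):
            best_val, best_idx = None, -1
            for kh in range(kernel_size):
                for kw in range(kernel_size):
                    r, c = oh * stride + kh, ow * stride + kw
                    # Strict ">" keeps the first maximum in scan order on ties.
                    if best_val is None or x[r][c] > best_val:
                        best_val, best_idx = x[r][c], r * width + c
            value_row.append(best_val)
            index_row.append(best_idx)
        values.append(value_row)
        indices.append(index_row)
    return values, indices


def max_pool2d(x: Plane, kernel_size: int, stride: Optional[int] = None) -> Plane:
    # Values-only pooling is element [0] of the pair, mirroring
    # aten.max_pool2d_with_indices.default(x)[0].
    return max_pool2d_with_indices(x, kernel_size, stride)[0]


x = [
    [1.0, 5.0, 2.0, 0.0],
    [3.0, 4.0, 7.0, 6.0],
    [9.0, 8.0, 1.0, 1.0],
    [0.0, 2.0, 3.0, 4.0],
]
vals, idx = max_pool2d_with_indices(x, kernel_size=2)
print(vals)  # [[5.0, 7.0], [9.0, 4.0]]
print(idx)   # [[1, 6], [8, 15]]
```

Since the values are produced in the same pass that finds each index, a single op with two outputs covers both the inference case (take `[0]`) and the training case (keep the indices).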

ayerofieiev-tt commented 3 weeks ago

@jdh8 to clarify, this does not block the compiler, and you were able to lower to `ttnn.max_pool2d` for the case where indices are not used?

jdh8 commented 3 weeks ago

I'm consulting @kevinwuTT for tech support on the workaround. https://github.com/tenstorrent/pytorch2.0_ttnn/pull/114#discussion_r1738850006

I hope we can find a way to lower an expression involving `__getitem__`:

```python
aten.max_pool2d_with_indices.default(x)[0]
```

davorchap commented 3 weeks ago

Indices won’t be ready for a while (they are needed only for training), so go with a workaround.
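
One shape such a workaround can take is a graph rewrite that turns `getitem(max_pool2d_with_indices(x, ...), 0)` into a values-only pool, applied only when no consumer reads element `1` (the indices). The sketch below is a framework-free toy under that assumption: the `Node` class, the `users` helper, and the op strings (including the `"ttnn.max_pool2d"` label) are hypothetical stand-ins, not the torch.fx API or the pass actually used in tenstorrent/pytorch2.0_ttnn#114.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical op labels for this toy IR (not real library identifiers).
POOL_WITH_INDICES = "aten.max_pool2d_with_indices.default"
GETITEM = "getitem"
POOL = "ttnn.max_pool2d"


@dataclass(eq=False)  # eq=False: nodes compare and hash by identity
class Node:
    op: str
    args: Tuple = ()


def users(nodes: List[Node], target: Node) -> List[Node]:
    """Return every node that takes `target` as a direct argument."""
    return [n for n in nodes if any(a is target for a in n.args)]


def lower_pool_getitem(nodes: List[Node]) -> List[Node]:
    """Rewrite getitem(pool_with_indices(...), 0) into a values-only pool
    when no consumer of the pool reads element 1 (the indices)."""
    for node in nodes:
        if node.op != GETITEM or not isinstance(node.args[0], Node):
            continue
        src, index = node.args
        if src.op != POOL_WITH_INDICES or index != 0:
            continue
        if all(u.op == GETITEM and u.args[1] == 0 for u in users(nodes, src)):
            # Turn the getitem node itself into the pool, so anything that
            # already consumed it keeps pointing at the same node object.
            node.op, node.args = POOL, src.args
    # Drop with-indices ops that no longer have any consumers.
    return [n for n in nodes if not (n.op == POOL_WITH_INDICES and not users(nodes, n))]


x = Node("placeholder")
pool = Node(POOL_WITH_INDICES, (x, 2))  # kernel_size=2
values = Node(GETITEM, (pool, 0))
out = Node("output", (values,))
print([n.op for n in lower_pool_getitem([x, pool, values, out])])
# ['placeholder', 'ttnn.max_pool2d', 'output']
```

If anything does read the indices, the graph is left untouched, which matches the inference-only scope of the workaround.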