jdh8 opened 3 months ago
@jdh8 To clarify: this does not block the compiler, and you were able to lower to ttnn.max_pool2d for the case where indices are not used?
I'm consulting @kevinwuTT for tech support on the workaround. https://github.com/tenstorrent/pytorch2.0_ttnn/pull/114#discussion_r1738850006
I hope we can find a way to lower an expression that involves __getitem__, such as:
aten.max_pool2d_with_indices.default(x)[0]
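One possible workaround, sketched here with torch.fx and not taken from PR #114, is a graph rewrite. It finds aten.max_pool2d_with_indices.default nodes whose only users are getitem(..., 0), which means the indices are never read, and replaces each pair with a single aten.max_pool2d.default call using the same arguments. A later lowering step could then map that node to ttnn.max_pool2d. The names lower_max_pool2d_with_indices and build_example are hypothetical, and the sketch assumes both aten ops take the same positional arguments (kernel_size, stride, padding, dilation, ceil_mode).

```python
import operator

import torch
import torch.fx as fx

aten = torch.ops.aten


def lower_max_pool2d_with_indices(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical pass: rewrite getitem(max_pool2d_with_indices(...), 0)
    into max_pool2d(...) when the indices output is never used."""
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != aten.max_pool2d_with_indices.default:
            continue
        users = list(node.users)
        # Only rewrite when every consumer reads element 0 (the pooled values).
        if not all(
            u.op == "call_function" and u.target is operator.getitem and u.args[1] == 0
            for u in users
        ):
            continue
        # Emit the indices-free op with the same arguments, right before the old node.
        with graph.inserting_before(node):
            pooled = graph.call_function(aten.max_pool2d.default, node.args, node.kwargs)
        # Redirect consumers of each getitem to the new node, then drop the dead nodes.
        for u in users:
            u.replace_all_uses_with(pooled)
            graph.erase_node(u)
        graph.erase_node(node)
    graph.lint()
    gm.recompile()
    return gm


def build_example() -> fx.GraphModule:
    """Build a graph equivalent to aten.max_pool2d_with_indices.default(x, [2, 2])[0]."""
    graph = fx.Graph()
    x = graph.placeholder("x")
    pooled_and_indices = graph.call_function(aten.max_pool2d_with_indices.default, (x, [2, 2]))
    values = graph.call_function(operator.getitem, (pooled_and_indices, 0))
    graph.output(values)
    return fx.GraphModule(torch.nn.Module(), graph)
```

Usage: gm = lower_max_pool2d_with_indices(build_example()) leaves a single aten.max_pool2d.default call in gm.graph. Nodes that do read index 1 are left untouched, so training graphs that need indices keep the original op.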
Indices won't be ready for a while (they are needed only for training), so go with a workaround.
Removed from the sprint, since indices are only needed for training. So far we are focusing more on inference.
As we found out in tenstorrent/pytorch2.0_ttnn#114, PyTorch converts torch.nn.functional.max_pool2d to aten.max_pool2d_with_indices even if indices are not requested. For the sake of completeness and a simple structure, I suggest making a version of max_pool2d that also outputs the indices.
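To make the proposed output concrete, here is a plain-Python reference for a single H x W plane. It is a sketch, not a TTNN API: the name max_pool2d_with_indices_ref is hypothetical, and it assumes no padding, no dilation, and floor mode. Each index is the flat position h * W + w within the input plane, which is the convention aten.max_pool2d_with_indices uses; ties keep the first maximum encountered.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def max_pool2d_with_indices_ref(
    x: Matrix, kernel: int, stride: int
) -> Tuple[Matrix, List[List[int]]]:
    """Hypothetical reference: max pooling over one plane, returning (values, indices)."""
    height, width = len(x), len(x[0])
    out_h = (height - kernel) // stride + 1
    out_w = (width - kernel) // stride + 1
    values: Matrix = [[0.0] * out_w for _ in range(out_h)]
    indices = [[0] * out_w for _ in range(out_h)]
    for oh in range(out_h):
        for ow in range(out_w):
            best_val, best_idx = None, -1
            # Scan the kernel window; strict '>' keeps the first maximum on ties.
            for kh in range(kernel):
                for kw in range(kernel):
                    h, w = oh * stride + kh, ow * stride + kw
                    if best_val is None or x[h][w] > best_val:
                        best_val, best_idx = x[h][w], h * width + w
            values[oh][ow] = best_val
            indices[oh][ow] = best_idx
    return values, indices
```

An inference-only caller keeps just the first element, e.g. values, _ = max_pool2d_with_indices_ref(x, 2, 2), which mirrors the [0] in aten.max_pool2d_with_indices.default(x)[0].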