jdh8 opened this issue 2 months ago
@jdh8 We need more details so we can pass the info to the Op team.
Can you please specify the individual problems with each op? Ideally we should have a single ticket per issue, categorized as follows:

- Op Category:
- Issue Category:
- Specific input: [params]. Link to an xfail test.

@jdh8 Please provide more details so we can file individual tickets with the Op leads.
@jdh8 For reshape: your test (tests/lowering/tensor_manipulation/test_reshape.py) is not actually dispatching `aten.reshape` to the compiler, but rather `aten.view`.

Your test module:

```python
def forward(self, x, new_shape):
    return torch.reshape(x, new_shape)
```

is passed to the compiler as

```python
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [21, 5]); arg0_1 = None
    return (view,)
```

which yields the same result, but without using `reshape`.
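For readers unfamiliar with the distinction, here is an illustrative sketch (not the compiler's actual logic; all helper names and the assumed input shape are hypothetical): on a contiguous tensor, a reshape can be satisfied by a view, which keeps the same elements and only reinterprets shape and strides, so no data movement is needed.

```python
def contiguous_strides(shape):
    """Row-major strides for a contiguous tensor of the given shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return strides[::-1]

def numel(shape):
    """Total element count of a shape."""
    n = 1
    for d in shape:
        n *= d
    return n

def can_view(shape, new_shape):
    # A contiguous tensor can be viewed as any shape with the same
    # element count, so no separate reshape kernel is required.
    return numel(shape) == numel(new_shape)

# The traced graph above views its input as [21, 5]: 105 elements either way
# (assuming a (5, 21) input for illustration).
assert can_view((5, 21), (21, 5))
assert contiguous_strides((21, 5)) == [5, 1]
```

This is why Dynamo can hand the compiler `aten.view` where the source code said `torch.reshape`.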
I don't see any code for processing the aten ops `reshape`, `concat`, and `repeat` in torch_ttnn/passes/lowering/to_tt_pass.py. That would explain why they are not being lowered to ttnn during compilation.
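To make the missing piece concrete, a lowering pass of this kind typically walks the traced graph and rewrites call targets from aten to ttnn equivalents. Below is a torch-free schematic only: the `Node` class, the mapping, and the ttnn target names are placeholders, not the real to_tt_pass.py code.

```python
from dataclasses import dataclass

@dataclass
class Node:
    op: str       # e.g. "call_function"
    target: str   # e.g. "aten.reshape"
    args: tuple

# Hypothetical aten -> ttnn target mapping for the ops discussed above.
ATEN_TO_TTNN = {
    "aten.reshape": "ttnn.reshape",
    "aten.concat": "ttnn.concat",
    "aten.repeat": "ttnn.repeat",
}

def to_tt_pass(graph):
    """Rewrite matching call targets in place; unmatched ops stay in aten."""
    for node in graph:
        if node.op == "call_function" and node.target in ATEN_TO_TTNN:
            node.target = ATEN_TO_TTNN[node.target]
    return graph

graph = [
    Node("call_function", "aten.reshape", (("x",), (21, 5))),
    Node("call_function", "aten.relu", (("x",),)),
]
to_tt_pass(graph)
assert graph[0].target == "ttnn.reshape"  # lowered
assert graph[1].target == "aten.relu"     # no mapping: stays in aten
```

If no such rewrite rule exists for an op, the node passes through untouched, which matches the behavior reported in this issue.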
@boris-drazic, I removed the conversions for these ops at one point to keep tenstorrent/pytorch2.0_ttnn#54 reviewable. Now I'm making a separate PR for each op.
@jdh8 For each sub-issue, can you give us a ttnn unit test that's failing? CC @ayerofieiev-tt
@tarafdarTT, I've made each sub-issue into a PR, so we can experiment on tests without affecting other ops. In each PR, I 'unlocked' some tests by removing `@pytest.mark.xfail`, and these are the failing tests:

- concat: tenstorrent/pytorch2.0_ttnn#188
- expand: tenstorrent/pytorch2.0_ttnn#146
- reshape: tenstorrent/pytorch2.0_ttnn#190

@jdh8 Oh okay, do you need anything from me (or the TM team) right now, or is that after you lower to ttnn?
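As background on the 'unlocked' tests mentioned above: `@pytest.mark.xfail` marks a test as expected to fail, so CI stays green while the op is unsupported; deleting the marker is what makes the failure visible. A minimal sketch (the test name and reason string are hypothetical; requires pytest):

```python
import pytest

# With the marker, a failing test is reported as "xfail" rather than "failed".
@pytest.mark.xfail(reason="aten.reshape is not lowered to ttnn yet")
def test_reshape_lowering():
    assert False  # stands in for the real compile-and-compare check

# Removing the decorator turns this into a hard failure,
# which is what surfaces the tickets linked above.
```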
The following ops refuse to lower to ttnn and stay in aten:

- concat: tenstorrent/pytorch2.0_ttnn#188
- expand: tenstorrent/pytorch2.0_ttnn#146
- repeat: tenstorrent/pytorch2.0_ttnn#189
- reshape: tenstorrent/pytorch2.0_ttnn#190