jdh8 opened 3 months ago
@kevinwuTT I am pretty sure I remember seeing linear lowered as permute/matmul. Maybe we have two conversions and one overrides the other?
linear was not being lowered previously, so it should be new.
The conversion for `aten.linear` is here, but it is somehow unused.
@jdh8 I think this is the guilty pattern here https://github.com/tenstorrent/pytorch2.0_ttnn/blob/ee0b425ba2473726ae5cfa8692f3c4920a51cfb0/torch_ttnn/patterns/linear.py#L4
Not this one. The graph remains the same (`permute` + `matmul`) after I remove the whole file (`linear.py`).

The only explanation is that PyTorch does not lower `torch.nn.functional.linear` to `aten.linear` but to `aten.permute` + `aten.matmul` instead. Whether we should convert this combination to `ttnn.linear` is a new question.
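To make the decomposition concrete: `torch.nn.functional.linear(x, W)` computes `x @ W.T`, so a tracer can legitimately emit a permute of the weight followed by a matmul instead of a single `aten.linear` node. A minimal torch-free sketch of that equivalence (nested lists standing in for tensors; the helper names here are illustrative, not PyTorch APIs):

```python
# Sketch (no PyTorch): linear(x, W) == matmul(x, permute(W)),
# which is why a traced graph may show aten.permute + aten.matmul.

def permute(m):
    """Transpose a 2-D matrix given as nested lists."""
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    """Naive 2-D matrix multiply."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def linear(x, w):
    """linear(x, W) = x @ W.T, mirroring torch.nn.functional.linear."""
    return matmul(x, permute(w))

x = [[1.0, 2.0], [3.0, 4.0]]   # input, shape (batch=2, in=2)
w = [[5.0, 6.0], [7.0, 8.0]]   # weight, shape (out=2, in=2)

print(linear(x, w))  # [[17.0, 23.0], [39.0, 53.0]]
```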
@jdh8, yes, I think it's a desirable fusion.
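A sketch of what such a fusion pass could look like, over a deliberately simplified op-list IR (the `(op, inputs, output)` tuples and pass name below are hypothetical illustrations, not the actual torch_ttnn graph representation):

```python
# Hypothetical fusion pass: rewrite matmul(x, permute(w)) into linear(x, w).
# The tuple-based IR here is an illustration only, not torch_ttnn's real IR.

def fuse_permute_matmul(graph):
    """Fuse aten.permute + aten.matmul pairs into a single ttnn.linear."""
    # Map each output name to its producing (op, inputs) pair.
    producers = {out: (op, ins) for op, ins, out in graph}
    fused, consumed = [], set()
    for op, ins, out in graph:
        if op == "aten.matmul":
            rhs = producers.get(ins[1])
            if rhs and rhs[0] == "aten.permute":
                # Second matmul input is a permuted weight: fold the pair.
                # (A real pass would also verify the permute has no other users.)
                fused.append(("ttnn.linear", [ins[0], rhs[1][0]], out))
                consumed.add(ins[1])
                continue
        fused.append((op, ins, out))
    # Drop permute nodes whose only use was folded into ttnn.linear.
    return [n for n in fused if n[2] not in consumed]

graph = [
    ("aten.permute", ["w"], "w_t"),
    ("aten.matmul", ["x", "w_t"], "y"),
]
print(fuse_permute_matmul(graph))  # [('ttnn.linear', ['x', 'w'], 'y')]
```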
Test case: see 701b5c09e95347049280b2bfa59eb2ae87fafcab
In `to_tt_pass.py`, we try to convert `aten.linear` to `ttnn.linear`. However, we don't see any `ttnn.linear` in the resulting graph, but `ttnn.permute` and `ttnn.matmul` instead.