tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ It enables running PyTorch 2.0 models on Tenstorrent hardware.
https://tenstorrent.github.io/tt-metal/latest/ttnn/

Failed to lower `aten.linear` to `ttnn.linear` #66

Open jdh8 opened 3 months ago

jdh8 commented 3 months ago

Test case: see 701b5c09e95347049280b2bfa59eb2ae87fafcab

In to_tt_pass.py, we try to convert aten.linear to ttnn.linear. However, the resulting graph contains no ttnn.linear; it contains ttnn.permute and ttnn.matmul instead.

ayerofieiev-tt commented 3 months ago

@kevinwuTT I am pretty sure I remember seeing linear lowered as permute/matmul. Maybe we got two conversions and one overrides the other?
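To illustrate the hypothesis: if two rewrite rules both match the same node, whichever runs last wins, so a direct aten.linear → ttnn.linear rule can be silently masked by a decomposition rule (or vice versa). This is a toy sketch with made-up op names, not the actual to_tt_pass.py code:

```python
# Toy model of two competing conversion passes over a flat op list.
# In the real compiler these would be FX graph rewrites.

def lower_linear(ops):
    # Direct rule: map aten.linear straight to ttnn.linear.
    return ["ttnn.linear" if op == "aten.linear" else op for op in ops]

def decompose_linear(ops):
    # Decomposition rule: split aten.linear into permute + matmul.
    out = []
    for op in ops:
        if op == "aten.linear":
            out += ["ttnn.permute", "ttnn.matmul"]
        else:
            out.append(op)
    return out

graph = ["aten.linear"]

# If the direct rule runs first, the decomposition never fires:
print(decompose_linear(lower_linear(graph)))   # ['ttnn.linear']

# In the opposite order, the decomposition masks the direct rule:
print(lower_linear(decompose_linear(graph)))   # ['ttnn.permute', 'ttnn.matmul']
```

Pass ordering alone would explain a "rule exists but appears unused" symptom.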

kevinwuTT commented 3 months ago

linear was not being lowered previously, so this should be new behavior.

jdh8 commented 3 months ago

https://github.com/tenstorrent/pytorch2.0_ttnn/blob/f87d7b3ec2947bb93e9d4b48ecf2c20d45edc369/torch_ttnn/passes/lowering/to_tt_pass.py#L71-L72

The conversion for aten.linear is here, but it is somehow unused.

ayerofieiev-tt commented 2 months ago

@jdh8 I think this is the guilty pattern here https://github.com/tenstorrent/pytorch2.0_ttnn/blob/ee0b425ba2473726ae5cfa8692f3c4920a51cfb0/torch_ttnn/patterns/linear.py#L4

jdh8 commented 2 months ago

Not this one. The graph remains the same (permute + matmul) after I remove the whole file (linear.py).

jdh8 commented 2 months ago

The only explanation is that PyTorch does not lower torch.nn.functional.linear to aten.linear; it lowers it to aten.permute + aten.matmul instead. Whether we should fuse this combination back into ttnn.linear is a separate question.
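For reference, the permute + matmul pair is numerically equivalent to linear: F.linear(x, w, b) computes x @ w.T + b, and the permute is exactly that weight transpose. A quick numpy sketch of the equivalence (shapes chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))   # (batch, in_features)
w = rng.standard_normal((5, 3))   # (out_features, in_features)
b = rng.standard_normal(5)        # (out_features,)

linear_out = x @ w.T + b                         # what aten.linear computes
decomposed = np.matmul(x, np.transpose(w)) + b   # permute (transpose) + matmul

assert np.allclose(linear_out, decomposed)
```

So no information is lost in the decomposition; the two forms are interchangeable, which is what makes a fusion pass feasible.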

ayerofieiev-tt commented 1 month ago

@jdh8, yes, I think it's a desirable fusion
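A minimal sketch of the fusion idea, again over a flat op list rather than a real FX graph (hypothetical, not the project's actual pass): collapse a permute immediately followed by a matmul into a single ttnn.linear.

```python
# Toy fusion pass: replace each adjacent (permute, matmul) pair with linear.
# A real pass would also have to verify the permute's dims, the dataflow
# between the two nodes, and the optional bias add, not just adjacency.

def fuse_permute_matmul(ops):
    fused, i = [], 0
    while i < len(ops):
        if i + 1 < len(ops) and ops[i] == "ttnn.permute" and ops[i + 1] == "ttnn.matmul":
            fused.append("ttnn.linear")
            i += 2
        else:
            fused.append(ops[i])
            i += 1
    return fused

print(fuse_permute_matmul(["ttnn.permute", "ttnn.matmul", "ttnn.add"]))
# ['ttnn.linear', 'ttnn.add']
```

In the actual compiler this would likely be expressed as a graph pattern match (similar in spirit to the pattern in torch_ttnn/patterns/linear.py) rather than a linear scan.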