tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ It enables running PyTorch 2.0 models on Tenstorrent hardware.
https://tenstorrent.github.io/tt-metal/latest/ttnn/

Failed to lower `scalar [*/] tensor` #110

Open · jdh8 opened 2 weeks ago

jdh8 commented 2 weeks ago
Out of the 8 cases ({+, -, *, /} × {LHS, RHS}), the following 2 cases do not work: `scalar * tensor` and `scalar / tensor`.

I'm still investigating which lines of code drive such conversions. In to_tt_pass.py, I only see conversions for aten.*.Tensor but not aten.*.Scalar. There is also a hook for relational ops (relational_scalar_ops) but not for arithmetic ops.

_Originally posted by @jdh8 in https://github.com/tenstorrent/pytorch2.0_ttnn/issues/100#issuecomment-2315009036_

jdh8 commented 2 weeks ago

@kevinwuTT, do you know where the conversion of tensor-scalar ops (such as `mul.Scalar`) takes place? I tried `git grep -F '.Scalar' **/*.py` but had no luck.

tests/lowering/misc/test_if.py:    assert target_0.count(torch.ops.aten.gt.Scalar) == 1
torch_ttnn/passes/lowering/to_tt_pass.py:    torch.ops.aten.eq.Scalar: ttnn.eq,
torch_ttnn/passes/lowering/to_tt_pass.py:    torch.ops.aten.lt.Scalar: ttnn.lt,
torch_ttnn/passes/lowering/to_tt_pass.py:            if node.target == torch.ops.aten.rsub.Scalar:
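So the relational overloads already go through a lookup table in to_tt_pass.py. If the arithmetic overloads can reuse the same mechanism, the change might be as small as the sketch below. This is purely hypothetical: I haven't verified the exact names of the ttnn binary ops, or whether they accept a Python scalar as the second operand.

```python
import torch
import ttnn

# Hypothetical mirror of the existing relational table (eq.Scalar -> ttnn.eq, ...)
# for the arithmetic overloads. The ttnn names are assumptions; they may be
# spelled ttnn.subtract / ttnn.multiply / ttnn.divide instead.
ARITH_SCALAR_OPS = {
    torch.ops.aten.add.Scalar: ttnn.add,
    torch.ops.aten.sub.Scalar: ttnn.sub,
    torch.ops.aten.mul.Scalar: ttnn.mul,
    torch.ops.aten.div.Scalar: ttnn.div,
}
```
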
kevinwuTT commented 2 weeks ago

If they're not found, then I don't think they're implemented yet. Also, sometimes the following happens before the graph reaches our passes:

out = tensor * 3

Intuitively, you would think this gets lowered to:

out = aten.mul.Scalar(tensor, 3)

But instead, torch lowers this into something like:

fill = aten.fill([], 3)
out = aten.mul.Tensor(tensor, fill)

I'm not sure if there's a way to change this behavior, but this might be a reason why some scalar ops are being lowered even though there are no direct scalar conversions in our passes. Is this what you're seeing?
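
A quick way to check is to trace the function down to aten ops with make_fx and inspect the graph. This is only a rough check, since the decompositions make_fx applies may not exactly match what our backend receives from dynamo/AOT:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x * 3

# Trace to aten and print the graph. Look at whether the 3 stays a plain scalar
# argument to aten.mul.Tensor or gets materialized into a tensor first.
gm = make_fx(f)(torch.randn(2, 2))
print(gm.graph)
```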

jdh8 commented 2 weeks ago

Yeah, I see tensor-scalar ops lowered anyway despite lacking an explicit conversion. PyTorch probably converts the scalars to tensors in advance.

Another unexpected lowering turned up in #114.

torch.nn.functional.max_pool2d(x)

PyTorch lowers this expression into:

aten.max_pool2d_with_indices.default(x)[0]

I don't have a good idea for lowering `__getitem__` (i.e. `[]`) yet... :frowning_face:

Also, tt-metal currently lacks an op that computes max pooling and its corresponding "argmax" indices simultaneously.
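
One possible workaround for the common case, where nothing else uses the returned indices, is a small FX rewrite that folds max_pool2d_with_indices(...)[0] back into plain aten.max_pool2d before the aten-to-ttnn conversion runs. A rough, untested sketch (it assumes aten.max_pool2d is easier for us to convert than the with-indices variant):

```python
import operator
import torch

def fold_maxpool_getitem(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # If only output [0] of max_pool2d_with_indices is consumed, rewrite it to
    # aten.max_pool2d so the indices never need to be lowered.
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target != operator.getitem:
            continue
        src, idx = node.args
        if (
            idx == 0
            and isinstance(src, torch.fx.Node)
            and src.op == "call_function"
            and src.target == torch.ops.aten.max_pool2d_with_indices.default
            and all(user is node for user in src.users)
        ):
            with gm.graph.inserting_before(src):
                pooled = gm.graph.call_function(
                    torch.ops.aten.max_pool2d.default, src.args, src.kwargs
                )
            node.replace_all_uses_with(pooled)
            gm.graph.erase_node(node)
            gm.graph.erase_node(src)
    gm.graph.lint()
    gm.recompile()
    return gm
```

This only sidesteps the missing combined op; models that actually consume the indices would still need a real argmax-style kernel.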