jdh8 opened 2 weeks ago
@kevinwuTT, do you know where the conversion of tensor-scalar ops (such as `mul.Scalar`) takes place? I tried `git grep -F '.Scalar' **/*.py` but was out of luck.
```
tests/lowering/misc/test_if.py: assert target_0.count(torch.ops.aten.gt.Scalar) == 1
torch_ttnn/passes/lowering/to_tt_pass.py: torch.ops.aten.eq.Scalar: ttnn.eq,
torch_ttnn/passes/lowering/to_tt_pass.py: torch.ops.aten.lt.Scalar: ttnn.lt,
torch_ttnn/passes/lowering/to_tt_pass.py: if node.target == torch.ops.aten.rsub.Scalar:
```
If they're not found, then I don't think they're implemented yet. Also, sometimes the following happens before we get to our passes:

```python
out = tensor * 3
```

Intuitively, you would expect this to be lowered to:

```python
out = aten.mul.Scalar(tensor, 3)
```

But instead, torch lowers it into something like:

```python
fill = aten.fill([], 3)
out = aten.mul.Tensor(tensor, fill)
```

I'm not sure if there's a way to change this behavior, but this might be why some scalar ops still get lowered even though our passes have no direct scalar conversions. Is this what you're seeing?
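A quick way to check what actually lands in the graph (a minimal sketch, not part of the repo's tooling) is to trace a tensor-scalar multiply down to the aten level with `make_fx` and inspect the node targets:

```python
# Hypothetical sketch: trace a tensor-scalar multiply to the aten level
# and inspect which overload (Scalar vs Tensor) actually appears.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(t):
    return t * 3  # tensor-scalar op at the Python level

gm = make_fx(f)(torch.randn(4))
# The node targets in the traced graph show whether the scalar survived
# as a .Scalar overload or was turned into a tensor first.
print(gm.graph)
```

The exact lowering may vary across PyTorch versions and decomposition tables, so the printed graph is the ground truth for the version in use.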
Yeah, I see tensor-scalar ops lowered anyway, despite the lack of an explicit conversion. PyTorch probably converts the scalars to tensors in advance.
Another unexpected lowering surfaced in #114:

```python
torch.nn.functional.max_pool2d(x)
```

PyTorch lowers this expression into:

```python
aten.max_pool2d_with_indices.default(x)[0]
```
I don't have a good idea for lowering `__getitem__` (i.e. `[]`) yet... :frowning_face: Also, `tt-metal` currently lacks an op that simultaneously computes maxpool and its corresponding "argmax".
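For what it's worth, the trailing `[0]` on a multi-output aten op shows up in the traced FX graph as an `operator.getitem` node, so a conversion pass could in principle match it directly. A minimal sketch (hypothetical, not existing repo code):

```python
# Hypothetical sketch: the [0] on a multi-output aten op appears in the
# FX graph as an operator.getitem node that a lowering pass can match.
import operator
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # ask for indices so the traced graph uses the with_indices variant
    out, idx = torch.nn.functional.max_pool2d(x, 2, return_indices=True)
    return out

gm = make_fx(f)(torch.randn(1, 1, 4, 4))
getitems = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target == operator.getitem
]
# Each getitem node carries (source_node, index) in node.args, so a pass
# can check whether index == 0 selects the pooled values and rewrite the
# pair (with_indices call + getitem) into a plain maxpool op.
```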
- `scalar * tensor` fails to compile
- `scalar / tensor` gives incorrect results by computing `tensor / scalar` instead

I'm still investigating which lines of code drive such conversions. In `to_tt_pass.py`, I only see conversions for `aten.*.Tensor` but not `aten.*.Scalar`. There is also a hook for relational ops (`relational_scalar_ops`) but not for arithmetic ops.

_Originally posted by @jdh8 in https://github.com/tenstorrent/pytorch2.0_ttnn/issues/100#issuecomment-2315009036_