tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

Broadcasting issues about pointwise binary ops #12852

Open jdh8 opened 3 months ago

jdh8 commented 3 months ago

I tested every pointwise binary op under several broadcasting settings, and the results differ by op category. I covered both unidirectional and bidirectional broadcasting with the following input shapes:

@pytest.mark.parametrize(
    "input_shapes",
    (((32, 32), (32, 32)), ((64,), (32, 64)), ((64, 32), (64, 1)), ((64, 1), (1, 64))),
)
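For reference, the output shape each pair should produce under the usual NumPy/PyTorch broadcasting rule can be sketched in plain Python. This is only an illustration of the rule; `broadcast_shape` is a hypothetical helper and is not part of the test suite or of TT-NN.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the NumPy/PyTorch-style broadcast shape of a and b.

    Hypothetical helper for illustration; raises ValueError if the
    shapes are not broadcastable.
    """
    out = []
    # Align from the trailing dimension; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(reversed(out))


# The four parametrized cases from the test above.
cases = (((32, 32), (32, 32)), ((64,), (32, 64)), ((64, 32), (64, 1)), ((64, 1), (1, 64)))
expected_shapes = [broadcast_shape(a, b) for a, b in cases]
```

Only the last pair needs both operands expanded (bidirectional); in the second and third pairs a single operand is expanded to match the other.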

I also tested 0-D cases such as ((), (32, 32)). These caused a floating point exception (core dumped) and halted the whole test run.

Arithmetic ops and xlogy

Unidirectional broadcasting passed: ((64,), (32, 64)), ((64, 32), (64, 1)).

Bidirectional broadcasting failed: ((64, 1), (1, 64)). It compiled but produced incorrect results.

E       assert False
E        +  where False = <built-in method allclose of type object at 0x7f768a48c8c0>(tensor([[3., 4., 4.,  ..., 2., 5., 4.],\n        [5., 6., 6.,  ..., 4., 7., 6.],\n        [3., 4., 4.,  ..., 2., 5., 4.]....., 2., 5., 4.],\n        [4., 5., 5.,  ..., 3., 6., 5.],\n        [6., 7., 7.,  ..., 5., 8., 7.]], dtype=torch.bfloat16), TorchTensor([[3., 3., 3.,  ..., 4., 1., 3.],\n             [5., 3., 3.,  ..., 4., 1., 3.],\n             [3., 3., 3.,  ...., 0.],\n             [0., 0., 0.,  ..., 0., 0., 0.],\n             [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.bfloat16))
E        +    where <built-in method allclose of type object at 0x7f768a48c8c0> = torch.allclose

tests/lowering/eltwise/binary/test_add.py:35: AssertionError
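To make the expected result concrete, here is a minimal pure-Python sketch of bidirectional broadcasting for addition, using small stand-in shapes (3, 1) and (1, 2) instead of (64, 1) and (1, 64). `broadcast_add` is a hypothetical reference function, not the TT-NN implementation.

```python
def broadcast_add(col, row):
    """Add a column (N x 1) and a row (1 x M) with bidirectional broadcasting.

    Hypothetical reference: out[i][j] = col[i][0] + row[0][j].
    """
    return [[c[0] + r for r in row[0]] for c in col]


col = [[1.0], [2.0], [3.0]]  # shape (3, 1)
row = [[10.0, 20.0]]         # shape (1, 2)
result = broadcast_add(col, row)  # shape (3, 2)
```

Every output element depends on one element from each operand, so a correct result should contain no rows or columns that ignore one of the inputs.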

Comparison ops

All cases involving broadcasting failed. Only inputs with identical shapes passed.

Logical ops

All cases failed, including the ones with identical shapes, and for inconsistent reasons: either a TT_THROW in a broadcast program factory or RuntimeError: Bool did not match BFloat16.

====================================================================================== short test summary info ======================================================================================
FAILED tests/lowering/eltwise/binary/test_logical_and.py::test_logical_and[input_shapes0] - RuntimeError: Bool did not match BFloat16
FAILED tests/lowering/eltwise/binary/test_logical_and.py::test_logical_and[input_shapes1] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_and.py::test_logical_and[input_shapes2] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_and.py::test_logical_and[input_shapes3] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_or.py::test_logical_or[input_shapes0] - RuntimeError: Bool did not match BFloat16
FAILED tests/lowering/eltwise/binary/test_logical_or.py::test_logical_or[input_shapes1] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_or.py::test_logical_or[input_shapes2] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_or.py::test_logical_or[input_shapes3] - RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp:21: tt::exception
FAILED tests/lowering/eltwise/binary/test_logical_xor.py::test_logical_xor[input_shapes0] - RuntimeError: Bool did not match BFloat16
FAILED tests/lowering/eltwise/binary/test_logical_xor.py::test_logical_xor[input_shapes1] - RuntimeError: Bool did not match BFloat16
FAILED tests/lowering/eltwise/binary/test_logical_xor.py::test_logical_xor[input_shapes2] - RuntimeError: Bool did not match BFloat16
FAILED tests/lowering/eltwise/binary/test_logical_xor.py::test_logical_xor[input_shapes3] - RuntimeError: Bool did not match BFloat16

_Originally posted by @jdh8 in https://github.com/tenstorrent/pytorch2.0_ttnn/issues/54#issuecomment-2298191687_

Edited on 2024-08-22 03:00Z

jdh8 commented 2 months ago

For comparison ops, an innermost size in [1, 31] is no longer considered a valid page size in tt-metal: https://github.com/tenstorrent/tt-metal/blob/6fbf494df30b6e228847bcd94574ca7a9a832e0d/tt_metal/impl/buffers/buffer.cpp#L25-L48

We call this "strict" page size for now. https://github.com/tenstorrent/pytorch2.0_ttnn/blob/a5ce2a677f6ea9c6ba704c21db1c771bff03ef8a/torch_ttnn/passes/lowering/to_tt_pass.py#L320-L323
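As a rough sketch of which test shapes this affects, the check below encodes only the statement above (innermost size in [1, 31] is rejected). It does not reproduce the logic in buffer.cpp or in to_tt_pass.py, and `innermost_in_rejected_range` is a hypothetical name.

```python
def innermost_in_rejected_range(shape):
    """True if the innermost dimension falls in [1, 31].

    Hypothetical helper: mirrors only the range quoted in this comment,
    not the actual validation code in tt-metal.
    """
    return len(shape) > 0 and 1 <= shape[-1] <= 31


cases = (((32, 32), (32, 32)), ((64,), (32, 64)), ((64, 32), (64, 1)), ((64, 1), (1, 64)))
# For each pair, record whether any operand has an innermost size in [1, 31].
affected = [any(innermost_in_rejected_range(s) for s in pair) for pair in cases]
```

Under this assumption, the pairs containing a (64, 1) operand are the ones flagged, since their innermost size is 1.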